154 lines
4.6 KiB
Lua
154 lines
4.6 KiB
Lua
local floor = math.floor
|
|
local insert = table.insert
|
|
local ipairs = ipairs
|
|
local log = math.log
|
|
local max = math.max
|
|
local print = print
|
|
|
|
local ars = require("ars")
|
|
local snes = require("snes")
|
|
local xnes = require("xnes")
|
|
local guided = require("guided")
|
|
|
|
-- try it all out on a dummy problem.
|
|
|
|
--- Return the "class" of an object: the __index table of its metatable.
-- Errors if `t` has no metatable (indexing nil).
local function typeof(t)
  local mt = getmetatable(t)
  return mt.__index
end
|
|
|
|
--- Return x multiplied by itself.
local function square(x)
  local squared = x * x
  return squared
end
|
|
|
|
-- Test objective: the negated sum of squared distances between x[i] and i/n,
-- where n = #x. The sum of squares is minimized (and the returned value is
-- maximized, at 0) when x = (arange(dims) + 1) / dims, i.e. (1/n, 2/n, ..., 1).
-- xNES should be able to find it almost exactly.
local function spherical(x)
  local n = #x -- loop-invariant; x is not modified below.
  local sum = 0
  --for i, v in ipairs(x) do sum = sum + square(v - i) end
  for i, v in ipairs(x) do sum = sum + square(v - i / n) end
  -- we need to negate this to turn it into a maximization problem.
  return -sum
end
|
|
|
|
-- Run settings, copied from hardmaru's simple_es_example.ipynb.
local dims = 100
local popsize = dims + 1
-- Initial sigma; progress logging compares the current sigma against it.
local sigma_init = 0.5
local iterations = 3000 -- 4000 was also tried.

-- Exactly one optimizer is active; the others are kept for quick switching.
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)

-- FIXME: needs a better interface than setting a field after construction.
es.min_refresh = 0.5
|
|
|
|
-- For xNES/SNES, replace the utilities with IGO-style truncation weights:
-- the first `pop5` utility slots (at least one) get weight 1, the rest 0,
-- and the vector is then normalized.
-- NOTE(review): assumes es.utility is indexed by rank, best first -- confirm.
if typeof(es) == xnes.Xnes or typeof(es) == snes.Snes then
  local pop5 = max(1, floor(es.popsize / 5))

  for i = 1, es.popsize do
    -- `<=`, not `<`: select exactly pop5 samples. With `<`, any population
    -- smaller than 10 (pop5 == 1) would get all-zero utilities.
    es.utility[i] = i <= pop5 and 1 or 0
  end

  -- The return value is ignored, so this presumably normalizes es.utility
  -- in place to sum to 1. TODO confirm against util.normalize_sums.
  local util = require "util"
  util.normalize_sums(es.utility)

  es.param_rate = 0.39
  es.sigma_rate = 0.39
  es.covar_rate = 0.39
  es.adaptive = false
end
|
|
|
|
-- (An unreachable `if false then ... end` scratch check of
-- ars.collect_best_indices, marked "TODO: delete me", was removed here.)
|
|
|
|
-- Per-run state. `asked` and `noise` persist across iterations because the
-- generic ask path below hands them back to es:ask() for reuse.
local asked, noise = nil, nil
local fitness = spherical(es:params())

-- Guided ES only: a flattened ring buffer of the last `history_len` steps
-- returned by es:tell(); `.shape` records its {rows, step length} layout.
local grad_history = {}
local write_pos = 0
local history_len = 10

for iter = 1, iterations do
  local kind = typeof(es)

  -- 1. Draw a population with whichever ask() variant this optimizer uses.
  if kind == snes.Snes and es.min_refresh ~= 1 then
    asked, noise = es:ask_mix()
  elseif kind == ars.Ars then
    asked, noise = es:ask()
  elseif kind == guided.Guided then
    asked, noise = es:ask(grad_history)
  else
    asked, noise = es:ask(asked, noise)
  end

  -- 2. Score every candidate on the test function.
  local scores = {}
  for idx, candidate in ipairs(asked) do
    scores[idx] = spherical(candidate)
  end

  -- 3. Report scores back. Only Guided returns a step that we record. Every
  -- other optimizer, ARS included, is told the scores alone. For ARS, a
  -- "use lips" variant that also passed the current fitness is disabled.
  if kind == guided.Guided then
    local step = es:tell(scores)
    for _, component in ipairs(step) do
      grad_history[write_pos + 1] = component
      write_pos = (write_pos + 1) % (history_len * #step)
    end
    grad_history.shape = {floor(#grad_history / #step), #step}
  else
    es:tell(scores)
  end

  -- 4. Track progress and log it every 100 iterations.
  fitness = spherical(es:params())
  if iter % 100 == 0 then
    local sigma = es.sigma
    if kind == snes.Snes then
      -- SNES exposes a list of stds (es.std); report their mean.
      local total = 0
      for _, s in ipairs(es.std) do total = total + s end
      sigma = total / #es.std
    end
    -- Shown in parentheses: log10 of sigma relative to its initial value.
    local inconvergence = sigma / sigma_init
    local fmt = "fitness at iteration %i: %.4f (%.4f)"
    print(fmt:format(iter, fitness, log(inconvergence) / log(10)))
  end
end
|
|
|
|
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))

-- Print the final parameters on one comma-separated line. Collecting the
-- pieces and joining them with table.concat avoids quadratic string
-- concatenation. It also never emits a trailing separator, even if es.dim
-- disagrees with the number of parameters returned.
local formatted = {}
for i, v in ipairs(es:params()) do
  formatted[i] = ("%.8f"):format(v)
end
print(table.concat(formatted, ", "))
|