smbot/es_test.lua
2019-03-11 07:15:41 +01:00

164 lines
4.6 KiB
Lua

local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
local guided = require("guided")
-- try it all out on a dummy problem.
--- Look up an object's prototype, i.e. the `__index` field of its metatable.
-- This script compares the result against constructors' prototypes
-- (e.g. `snes.Snes`) to tell which optimizer is in use.
-- Raises an error if `t` has no metatable.
local function typeof(t)
	local mt = getmetatable(t)
	return mt.__index
end
--- Square a number.
-- Uses multiplication (not `^`) so integers stay integers on Lua 5.3+.
-- @tparam number x
-- @treturn number x * x
local function square(x)
	return x * x
end
--- Dummy objective to maximize: the negated squared distance from `x`
-- to the point whose i-th coordinate is i / n, where n = #x,
-- i.e. (1/n, 2/n, ..., n/n).
-- Its global maximum is therefore 0, attained exactly at that point, and
-- xNES should be able to find it almost exactly.
-- (An earlier variant used square(v - i), whose optimum was
-- arange(dims) + 1 = (1, 2, ..., n).)
-- @tparam {number,...} x candidate parameter vector
-- @treturn number  -(sum of squared errors), always <= 0
local function spherical(x)
	local n = #x -- loop-invariant; hoisted out of the loop.
	local sum = 0
	for i, v in ipairs(x) do sum = sum + square(v - i / n) end
	-- negate so that optimizers which maximize fitness minimize the error.
	return -sum
end
-- problem and optimizer settings.
-- i'm just copying settings from hardmaru's simple_es_example.ipynb.
local iterations = 3000 --4000
-- set to true to try the larger 100-dimensional configuration instead.
local large_problem = false
local dims, popsize
if large_problem then
	dims = 100
	popsize = dims + 1
else
	dims = 30
	popsize = 99
end
local sigma_init = 0.5
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
es.min_refresh = 1.0 -- FIXME: needs a better interface.
if typeof(es) == xnes.Xnes
or typeof(es) == snes.Snes
then
	-- use IGO recommendations: equal weight for the best fifth of the
	-- population (always at least one member), zero weight for the rest.
	-- NOTE(review): assumes es.utility is indexed by rank, best first;
	-- confirm against xnes/snes.
	local selected = max(1, floor(es.popsize / 5))
	for rank = 1, es.popsize do
		-- `<=` so exactly `selected` members get weight. With `<`, only
		-- selected - 1 did, and when selected == 1 every utility was 0,
		-- leaving nothing to normalize.
		es.utility[rank] = rank <= selected and 1 or 0
	end
	-- presumably rescales es.utility in place so the weights sum to 1.
	local util = require "util"
	util.normalize_sums(es.utility)
	es.param_rate = 0.2 --39
	es.sigma_rate = 0.05 --39
	es.covar_rate = 0.1 --39
	es.adaptive = false
end
-- Disabled scratch check of ars.collect_best_indices and of reward
-- normalization. Because of `if false`, none of this ever runs.
-- NOTE(review): the indexing below assumes collect_best_indices returns
-- 1-based pair indices into `scored`, where pair k occupies slots 2k-1 and
-- 2k (a positive/negative evaluation pair). TODO confirm against ars.lua.
if false then -- TODO: delete me
local nn = require("nn")
local util = require("util")
local insert = table.insert
-- presumably a 10-element sequence of increasing values, i.e. 5 pairs.
local scored = nn.arange(10)
local indices = ars.collect_best_indices(scored, 3, true)
-- show each selected pair index alongside both of its scores.
for i, ind in ipairs(indices) do
print(ind, ":", scored[ind * 2 - 1], scored[ind * 2 - 0])
end
-- gather the selected pairs, in selection order, as a flat list.
local top_rewards = {}
for _, ind in ipairs(indices) do
insert(top_rewards, scored[ind * 2 - 1])
insert(top_rewards, scored[ind * 2 - 0])
end
-- this shouldn't make a difference to the final print: each reward below
-- is divided by reward_dev, so a uniform rescale should cancel out
-- (assuming calc_mean_dev's second result scales linearly with the data).
top_rewards = util.normalize_sums(top_rewards)
print(nn.pp(top_rewards))
local _, reward_dev = util.calc_mean_dev(top_rewards)
print(reward_dev)
-- per selected pair: (positive - negative) score, scaled by reward_dev.
for i, ind in ipairs(indices) do
local pos = top_rewards[i * 2 - 1]
local neg = top_rewards[i * 2 - 0]
local reward = pos - neg
reward = reward / reward_dev
print(reward)
end
-- stop the script here. Lua only allows `return` as the last statement of
-- a block, hence the `do ... end` wrapper.
do return end
end
-- state carried between iterations of the optimization loop.
local asked = nil -- previous population, for caching purposes.
local noise = nil -- previous noise, for caching purposes.
local current_cost = spherical(es:params())
-- ring buffer of the most recent steps returned by guided.Guided's tell,
-- passed back into its ask.
local past_grads = {}
local pgi = 0 -- next write position in past_grads (0-based).
local pgn = 10 -- number of past steps the ring buffer holds.
-- the optimizer's class never changes, so look it up once.
local es_class = typeof(es)

-- draw a new population (and its noise) from es, using whichever ask
-- variant the optimizer's class expects.
local function ask_population()
	if es_class == snes.Snes and es.min_refresh ~= 1 then
		return es:ask_mix()
	elseif es_class == ars.Ars then
		return es:ask()
	elseif es_class == guided.Guided then
		return es:ask(past_grads)
	else
		return es:ask(asked, noise)
	end
end

-- evaluate the objective for every member of a population.
local function score_population(population)
	local scores = {}
	for j, candidate in ipairs(population) do
		scores[j] = spherical(candidate)
	end
	return scores
end

-- report scores back to es. For guided.Guided, also record the returned
-- step in the past_grads ring buffer.
local function tell_population(scores)
	if es_class == guided.Guided then
		local step = es:tell(scores)
		for _, v in ipairs(step) do
			past_grads[pgi + 1] = v
			pgi = (pgi + 1) % (pgn * #step)
		end
		-- rows = steps stored so far, columns = step length.
		past_grads.shape = {floor(#past_grads / #step), #step}
	else
		-- for ars.Ars, passing current_cost as a second argument ("use lips")
		-- was tried and is currently disabled.
		es:tell(scores)
	end
end

-- average step size: es.sigma, or for snes.Snes the mean of its
-- per-dimension es.std.
local function mean_sigma()
	if es_class ~= snes.Snes then return es.sigma end
	local total = 0
	for _, v in ipairs(es.std) do total = total + v end
	return total / #es.std
end

for iteration = 1, iterations do
	asked, noise = ask_population()
	tell_population(score_population(asked))
	current_cost = spherical(es:params())
	if iteration % 100 == 0 then
		-- log10 of how far the step size has shrunk from its initial value.
		local inconvergence = mean_sigma() / sigma_init
		local fmt = "fitness at iteration %i: %.4f (%.4f)"
		print(fmt:format(iteration, current_cost, log(inconvergence) / log(10)))
	end
end
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))
-- print the final parameters as a comma-separated list. table.concat avoids
-- quadratic string building and puts separators only between elements,
-- without relying on es.dim matching the parameter count.
local formatted = {}
for j, v in ipairs(es:params()) do
	formatted[j] = ("%.8f"):format(v)
end
local s = table.concat(formatted, ", ")
print(s)