smbot/es_test.lua
2018-06-30 20:13:54 +02:00

110 lines
3.4 KiB
Lua

local floor = math.floor
local ipairs = ipairs
local log = math.log
local print = print
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
-- try it all out on a dummy problem.

--- Return x squared.
local function square(x) return x * x end

--- Dummy objective: the negated sum of squared distances between x[i] and i.
-- The un-negated sum is smallest (zero) at x = {1, 2, ..., #x}, i.e. numpy's
-- arange(dims) + 1, so negating it makes that point the maximum. xNES should
-- be able to find it almost exactly.
-- @param x sequence of numbers (walked with ipairs)
-- @treturn number 0 at the optimum, negative everywhere else
local function spherical(x)
	local total = 0
	for index, value in ipairs(x) do
		local offset = value - index
		total = total + square(offset)
	end
	-- the optimizers maximize, so flip the sign of the cost.
	return -total
end
-- hyperparameters copied from hardmaru's simple_es_example.ipynb.
local iterations, dims, sigma_init = 3000, 100, 0.5 -- alternative iterations: 4000
local popsize = dims + 1

-- the optimizer under test. to compare, swap in one of:
--   xnes.Xnes(dims, popsize, 0.1, sigma_init)
--   snes.Snes(dims, popsize, 0.1, sigma_init)
--   ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
-- FIXME: needs a better interface than poking a field after construction.
es.min_refresh = 0.7
-- (an `if false then` scratch test of ars.collect_best_indices that could
-- never run was removed from here, as its TODO requested.)
-- results of the previous ask(); passed back into es:ask() for optimizers
-- other than SNES/ARS (presumably so it can reuse the tables).
local asked = nil
local noise = nil
-- objective value at the current mean parameters. note that spherical()
-- returns a *negated* cost, so despite the name this climbs towards 0.
local current_cost = spherical(es:params())

-- the optimizer's class can't change during the run, so decide once which
-- ask/report API it needs instead of re-checking its metatable every pass.
local es_class = getmetatable(es).__index
local is_snes = es_class == snes.Snes
local is_ars = es_class == ars.Ars

for i = 1, iterations do
	if is_snes then
		asked, noise = es:ask_mix()
	elseif is_ars then
		asked, noise = es:ask()
	else
		asked, noise = es:ask(asked, noise)
	end

	local scores = {}
	for j, candidate in ipairs(asked) do
		scores[j] = spherical(candidate)
	end
	-- ARS experiment (disabled): es:tell(scores, current_cost) -- use lips
	es:tell(scores)

	current_cost = spherical(es:params())

	if i % 100 == 0 then
		local sigma = es.sigma
		if is_snes then
			-- SNES keeps one std per dimension; report their mean instead.
			sigma = 0
			for _, v in ipairs(es.std) do sigma = sigma + v end
			sigma = sigma / #es.std
		end
		-- log10 of how much sigma has shrunk relative to its starting value.
		local inconvergence = sigma / sigma_init
		local fmt = "fitness at iteration %i: %.4f (%.4f)"
		print(fmt:format(i, current_cost, log(inconvergence) / log(10)))
	end
end
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))

-- dump the final parameters as a comma-separated list. table.concat inserts
-- the separators itself, so this no longer relies on es.dim matching the
-- length of es:params(), and avoids O(n^2) string concatenation.
local formatted = {}
for i, v in ipairs(es:params()) do
	formatted[i] = ("%.8f"):format(v)
end
local s = table.concat(formatted, ", ")
print(s)