add antithetic sampling for xNES

This commit is contained in:
Connor Olding 2018-06-10 16:33:38 +02:00
parent 695730335c
commit 0100934ac4
3 changed files with 36 additions and 8 deletions

View file

@ -98,8 +98,6 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial,
"cfg.unperturbed_trial must be true to use cfg.ars_lips") "cfg.unperturbed_trial must be true to use cfg.ars_lips")
assert(not cfg.ars_lips or cfg.negate_trials, assert(not cfg.ars_lips or cfg.negate_trials,
"cfg.negate_trials must be true to use cfg.ars_lips") "cfg.negate_trials must be true to use cfg.ars_lips")
assert(not cfg.es == 'xnes' or not cfg.negate_trials,
"cfg.negate_trials is not yet compatible with xNES")
assert(not cfg.adamant, assert(not cfg.adamant,
"cfg.adamant not yet re-implemented") "cfg.adamant not yet re-implemented")

View file

@ -452,10 +452,8 @@ local function init()
if cfg.es == 'xnes' then if cfg.es == 'xnes' then
-- if you get an out of memory error, you can't use xNES. sorry! -- if you get an out of memory error, you can't use xNES. sorry!
-- maybe there'll be a patch for FCEUX in the future. -- maybe there'll be a patch for FCEUX in the future.
local trials = cfg.epoch_trials es = xnes.Xnes(network.n_param, cfg.epoch_trials,
if cfg.negate_trials then trials = trials * 2 end cfg.learning_rate, cfg.deviation, cfg.negate_trials)
es = xnes.Xnes(network.n_param, trials, cfg.learning_rate,
cfg.deviation, cfg.negate_trials)
elseif cfg.es == 'ars' then elseif cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.learning_rate, cfg.deviation, cfg.negate_trials) cfg.learning_rate, cfg.deviation, cfg.negate_trials)

View file

@ -75,12 +75,15 @@ local function make_covars(dims, sigma, out)
return covars return covars
end end
function Xnes:init(dims, popsize, learning_rate, sigma) function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
-- heuristic borrowed from CMA-ES: -- heuristic borrowed from CMA-ES:
self.dims = dims self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims))) self.popsize = popsize or 4 + (3 * floor(log(dims)))
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.sigma = sigma or 1 self.sigma = sigma or 1
self.antithetic = antithetic and true or false
if self.antithetic then self.popsize = self.popsize * 2 end
self.utility = make_utility(self.popsize) self.utility = make_utility(self.popsize)
@ -119,6 +122,25 @@ function Xnes:ask_once(asked, noise)
return asked, noise return asked, noise
end end
function Xnes:ask_twice(asked0, asked1, noise0, noise1)
asked0 = asked0 or zeros(self.dims)
asked1 = asked1 or zeros(self.dims)
noise0 = noise0 or {}
noise1 = noise1 or {}
for i=1, self.dims do noise0[i] = normal() end
noise0.shape = {#noise0}
dot_mv(self.covars, noise0, asked0)
for i, v in ipairs(asked0) do
asked0[i] = self.mean[i] + self.sigma * v
asked1[i] = self.mean[i] - self.sigma * v
end
for i, v in ipairs(noise0) do noise1[i] = -v end
return asked0, asked1, noise0, noise1
end
function Xnes:ask(asked, noise) function Xnes:ask(asked, noise)
-- return a list of parameters for the user to score, -- return a list of parameters for the user to score,
-- and later pass to :tell(). -- and later pass to :tell().
@ -130,7 +152,17 @@ function Xnes:ask(asked, noise)
noise = {} noise = {}
for i=1, self.popsize do noise[i] = zeros(self.dims) end for i=1, self.popsize do noise[i] = zeros(self.dims) end
end end
for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end
if self.antithetic then
for i=1, self.popsize do
self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1])
end
else
for i=1, self.popsize do
self:ask_once(asked[i], noise[i])
end
end
self.noise = noise self.noise = noise
return asked, noise return asked, noise
end end