add antithetic sampling for xNES
This commit is contained in:
parent
695730335c
commit
0100934ac4
3 changed files with 36 additions and 8 deletions
|
@ -98,8 +98,6 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial,
|
|||
"cfg.unperturbed_trial must be true to use cfg.ars_lips")
|
||||
assert(not cfg.ars_lips or cfg.negate_trials,
|
||||
"cfg.negate_trials must be true to use cfg.ars_lips")
|
||||
assert(not cfg.es == 'xnes' or not cfg.negate_trials,
|
||||
"cfg.negate_trials is not yet compatible with xNES")
|
||||
|
||||
assert(not cfg.adamant,
|
||||
"cfg.adamant not yet re-implemented")
|
||||
|
|
6
main.lua
6
main.lua
|
@ -452,10 +452,8 @@ local function init()
|
|||
if cfg.es == 'xnes' then
|
||||
-- if you get an out of memory error, you can't use xNES. sorry!
|
||||
-- maybe there'll be a patch for FCEUX in the future.
|
||||
local trials = cfg.epoch_trials
|
||||
if cfg.negate_trials then trials = trials * 2 end
|
||||
es = xnes.Xnes(network.n_param, trials, cfg.learning_rate,
|
||||
cfg.deviation, cfg.negate_trials)
|
||||
es = xnes.Xnes(network.n_param, cfg.epoch_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
elseif cfg.es == 'ars' then
|
||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
|
|
36
xnes.lua
36
xnes.lua
|
@ -75,12 +75,15 @@ local function make_covars(dims, sigma, out)
|
|||
return covars
|
||||
end
|
||||
|
||||
function Xnes:init(dims, popsize, learning_rate, sigma)
|
||||
function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self.utility = make_utility(self.popsize)
|
||||
|
||||
|
@ -119,6 +122,25 @@ function Xnes:ask_once(asked, noise)
|
|||
return asked, noise
|
||||
end
|
||||
|
||||
function Xnes:ask_twice(asked0, asked1, noise0, noise1)
    -- produce a mirrored (antithetic) pair of candidates from a single draw.
    -- noise0 is filled with self.dims values from normal(), and noise1 with
    -- their negations. asked0 and asked1 receive mean + sigma * v and
    -- mean - sigma * v respectively, where v is noise0 after dot_mv with
    -- self.covars. any argument left nil is allocated here.
    -- returns asked0, asked1, noise0, noise1 (the same tables that were
    -- passed in, when given).
    asked0 = asked0 or zeros(self.dims)
    asked1 = asked1 or zeros(self.dims)
    noise0 = noise0 or {}
    noise1 = noise1 or {}

    for i=1, self.dims do noise0[i] = normal() end
    noise0.shape = {#noise0}

    -- asked0 doubles as the output buffer for the transformed noise;
    -- it is overwritten with the final candidate in the loop below.
    dot_mv(self.covars, noise0, asked0)
    for i, v in ipairs(asked0) do
        asked0[i] = self.mean[i] + self.sigma * v
        asked1[i] = self.mean[i] - self.sigma * v
    end

    for i, v in ipairs(noise0) do noise1[i] = -v end
    -- give the mirrored vector the same shape tag as noise0, so a freshly
    -- allocated noise1 is not left without one.
    -- NOTE(review): presumably consumers of the noise read .shape; confirm.
    noise1.shape = {#noise1}

    return asked0, asked1, noise0, noise1
end
|
||||
|
||||
function Xnes:ask(asked, noise)
|
||||
-- return a list of parameters for the user to score,
|
||||
-- and later pass to :tell().
|
||||
|
@ -130,7 +152,17 @@ function Xnes:ask(asked, noise)
|
|||
noise = {}
|
||||
for i=1, self.popsize do noise[i] = zeros(self.dims) end
|
||||
end
|
||||
for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end
|
||||
|
||||
if self.antithetic then
    -- ask_twice fills two adjacent slots (i and i+1) per call, so walk the
    -- population in steps of 2. stepping by 1 would overwrite every mirrored
    -- sample with a fresh draw on the following iteration.
    for i=1, self.popsize, 2 do
        self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1])
    end
else
    for i=1, self.popsize do
        self:ask_once(asked[i], noise[i])
    end
end
|
||||
|
||||
self.noise = noise
|
||||
return asked, noise
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue