add antithetic sampling for xNES
This commit is contained in:
parent
695730335c
commit
0100934ac4
3 changed files with 36 additions and 8 deletions
|
@ -98,8 +98,6 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial,
|
||||||
"cfg.unperturbed_trial must be true to use cfg.ars_lips")
|
"cfg.unperturbed_trial must be true to use cfg.ars_lips")
|
||||||
assert(not cfg.ars_lips or cfg.negate_trials,
|
assert(not cfg.ars_lips or cfg.negate_trials,
|
||||||
"cfg.negate_trials must be true to use cfg.ars_lips")
|
"cfg.negate_trials must be true to use cfg.ars_lips")
|
||||||
assert(not cfg.es == 'xnes' or not cfg.negate_trials,
|
|
||||||
"cfg.negate_trials is not yet compatible with xNES")
|
|
||||||
|
|
||||||
assert(not cfg.adamant,
|
assert(not cfg.adamant,
|
||||||
"cfg.adamant not yet re-implemented")
|
"cfg.adamant not yet re-implemented")
|
||||||
|
|
6
main.lua
6
main.lua
|
@ -452,10 +452,8 @@ local function init()
|
||||||
if cfg.es == 'xnes' then
|
if cfg.es == 'xnes' then
|
||||||
-- if you get an out of memory error, you can't use xNES. sorry!
|
-- if you get an out of memory error, you can't use xNES. sorry!
|
||||||
-- maybe there'll be a patch for FCEUX in the future.
|
-- maybe there'll be a patch for FCEUX in the future.
|
||||||
local trials = cfg.epoch_trials
|
es = xnes.Xnes(network.n_param, cfg.epoch_trials,
|
||||||
if cfg.negate_trials then trials = trials * 2 end
|
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||||
es = xnes.Xnes(network.n_param, trials, cfg.learning_rate,
|
|
||||||
cfg.deviation, cfg.negate_trials)
|
|
||||||
elseif cfg.es == 'ars' then
|
elseif cfg.es == 'ars' then
|
||||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||||
|
|
36
xnes.lua
36
xnes.lua
|
@ -75,12 +75,15 @@ local function make_covars(dims, sigma, out)
|
||||||
return covars
|
return covars
|
||||||
end
|
end
|
||||||
|
|
||||||
function Xnes:init(dims, popsize, learning_rate, sigma)
|
function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||||
-- heuristic borrowed from CMA-ES:
|
-- heuristic borrowed from CMA-ES:
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||||
self.sigma = sigma or 1
|
self.sigma = sigma or 1
|
||||||
|
self.antithetic = antithetic and true or false
|
||||||
|
|
||||||
|
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||||
|
|
||||||
self.utility = make_utility(self.popsize)
|
self.utility = make_utility(self.popsize)
|
||||||
|
|
||||||
|
@ -119,6 +122,25 @@ function Xnes:ask_once(asked, noise)
|
||||||
return asked, noise
|
return asked, noise
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Xnes:ask_twice(asked0, asked1, noise0, noise1)
    -- produce one antithetic (mirrored) pair of candidate solutions
    -- from a single noise vector:
    --   asked0 = mean + sigma * (covars * noise0)
    --   asked1 = mean - sigma * (covars * noise0)
    --   noise1 = -noise0
    -- every argument is optional; tables that are passed in are
    -- overwritten in place, missing ones are freshly allocated.
    -- returns asked0, asked1, noise0, noise1.
    asked0 = asked0 or zeros(self.dims)
    asked1 = asked1 or zeros(self.dims)
    noise0 = noise0 or {}
    noise1 = noise1 or {}

    -- one random draw per dimension.
    -- NOTE(review): normal() is presumably a standard normal sampler; confirm.
    for i=1, self.dims do noise0[i] = normal() end
    noise0.shape = {#noise0}

    -- asked0 doubles as scratch space: dot_mv is expected to write
    -- covars * noise0 into it before it is turned into the actual sample.
    dot_mv(self.covars, noise0, asked0)
    for i, v in ipairs(asked0) do
        asked0[i] = self.mean[i] + self.sigma * v
        asked1[i] = self.mean[i] - self.sigma * v
    end

    for i, v in ipairs(noise0) do noise1[i] = -v end
    -- keep noise1 consistent with noise0: without this, a noise1 that
    -- was allocated above (the `or {}` default) would lack .shape.
    noise1.shape = {#noise1}

    return asked0, asked1, noise0, noise1
end
|
||||||
|
|
||||||
function Xnes:ask(asked, noise)
|
function Xnes:ask(asked, noise)
|
||||||
-- return a list of parameters for the user to score,
|
-- return a list of parameters for the user to score,
|
||||||
-- and later pass to :tell().
|
-- and later pass to :tell().
|
||||||
|
@ -130,7 +152,17 @@ function Xnes:ask(asked, noise)
|
||||||
noise = {}
|
noise = {}
|
||||||
for i=1, self.popsize do noise[i] = zeros(self.dims) end
|
for i=1, self.popsize do noise[i] = zeros(self.dims) end
|
||||||
end
|
end
|
||||||
for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end
|
|
||||||
|
if self.antithetic then
|
||||||
|
-- step by 2: each ask_twice call fills the slot pair (i, i+1).
-- with a step of 1, every mirrored sample would be overwritten by the
-- next call's primary sample, and the final call would index popsize+1.
-- popsize is even here because Xnes:init doubles it when antithetic.
for i=1, self.popsize, 2 do
    self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1])
end
|
||||||
|
else
|
||||||
|
for i=1, self.popsize do
|
||||||
|
self:ask_once(asked[i], noise[i])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
self.noise = noise
|
self.noise = noise
|
||||||
return asked, noise
|
return asked, noise
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue