diff --git a/config.lua b/config.lua index 5fa567b..096d7a6 100644 --- a/config.lua +++ b/config.lua @@ -98,8 +98,6 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial, "cfg.unperturbed_trial must be true to use cfg.ars_lips") assert(not cfg.ars_lips or cfg.negate_trials, "cfg.negate_trials must be true to use cfg.ars_lips") -assert(not cfg.es == 'xnes' or not cfg.negate_trials, - "cfg.negate_trials is not yet compatible with xNES") assert(not cfg.adamant, "cfg.adamant not yet re-implemented") diff --git a/main.lua b/main.lua index d24d744..ba79f7d 100644 --- a/main.lua +++ b/main.lua @@ -452,10 +452,8 @@ local function init() if cfg.es == 'xnes' then -- if you get an out of memory error, you can't use xNES. sorry! -- maybe there'll be a patch for FCEUX in the future. - local trials = cfg.epoch_trials - if cfg.negate_trials then trials = trials * 2 end - es = xnes.Xnes(network.n_param, trials, cfg.learning_rate, - cfg.deviation, cfg.negate_trials) + es = xnes.Xnes(network.n_param, cfg.epoch_trials, + cfg.learning_rate, cfg.deviation, cfg.negate_trials) elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.learning_rate, cfg.deviation, cfg.negate_trials) diff --git a/xnes.lua b/xnes.lua index b3c2fe5..cc2c942 100644 --- a/xnes.lua +++ b/xnes.lua @@ -75,12 +75,15 @@ local function make_covars(dims, sigma, out) return covars end -function Xnes:init(dims, popsize, learning_rate, sigma) +function Xnes:init(dims, popsize, learning_rate, sigma, antithetic) -- heuristic borrowed from CMA-ES: self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) self.sigma = sigma or 1 + self.antithetic = antithetic and true or false + + if self.antithetic then self.popsize = self.popsize * 2 end self.utility = make_utility(self.popsize) @@ -119,6 +122,25 @@ function Xnes:ask_once(asked, noise) return asked, noise end +function Xnes:ask_twice(asked0, asked1, noise0, noise1) + asked0 = asked0 or zeros(self.dims) + asked1 = asked1 or zeros(self.dims) + noise0 = noise0 or {} + noise1 = noise1 or {} + + for i=1, self.dims do noise0[i] = normal() end + noise0.shape = {#noise0} + + dot_mv(self.covars, noise0, asked0) + for i, v in ipairs(asked0) do + asked0[i] = self.mean[i] + self.sigma * v + asked1[i] = self.mean[i] - self.sigma * v + end + for i, v in ipairs(noise0) do noise1[i] = -v end + + return asked0, asked1, noise0, noise1 +end + function Xnes:ask(asked, noise) -- return a list of parameters for the user to score, -- and later pass to :tell(). @@ -130,7 +152,17 @@ function Xnes:ask(asked, noise) noise = {} for i=1, self.popsize do noise[i] = zeros(self.dims) end end - for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end + + if self.antithetic then + for i=1, self.popsize do + self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1]) + end + else + for i=1, self.popsize do + self:ask_once(asked[i], noise[i]) + end + end + self.noise = noise return asked, noise end