From d87b8e7118de1ff2dfa1d53037973826f877234f Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sun, 10 Jun 2018 16:38:25 +0200 Subject: [PATCH] add mean adaptation hyperparameter --- config.lua | 3 +++ main.lua | 2 ++ xnes.lua | 10 ++++------ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/config.lua b/config.lua index 096d7a6..5ccf87d 100644 --- a/config.lua +++ b/config.lua @@ -32,6 +32,9 @@ local common_cfg = { time_inputs = true, -- binary inputs of global frame count normalize_inputs = false, + learning_rate = 1.0, + mean_adapt = 1.0, -- for xNES + es = 'ars', ars_lips = false, adamant = false, -- run steps through AMSgrad. diff --git a/main.lua b/main.lua index b3be9c3..1cf985f 100644 --- a/main.lua +++ b/main.lua @@ -456,6 +456,8 @@ local function init() -- maybe there'll be a patch for FCEUX in the future. es = xnes.Xnes(network.n_param, cfg.epoch_trials, cfg.learning_rate, cfg.deviation, cfg.negate_trials) + -- TODO: clean this up into an interface: + es.mean_adapt = cfg.mean_adapt elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.learning_rate, cfg.deviation, cfg.negate_trials) diff --git a/xnes.lua b/xnes.lua index cc33599..fb04695 100644 --- a/xnes.lua +++ b/xnes.lua @@ -71,17 +71,15 @@ function Xnes:init(dims, popsize, learning_rate, sigma, antithetic) --self.log_sigma = log(self.sigma) --self.log_covars = zeros{dims, dims} --for i, v in ipairs(self.covars) do self.log_covars[i] = log(v) end + + self.mean_adapt = 1.0 end -function Xnes:params(new_mean, new_covars) +function Xnes:params(new_mean) if new_mean ~= nil then assert(#self.mean == #new_mean, "new parameters have the wrong size") for i, v in ipairs(new_mean) do self.mean[i] = v end end - if new_covars ~= nil then - -- TODO: assert determinant of new_covars is 1. - error("TODO") - end return self.mean end @@ -183,7 +181,7 @@ function Xnes:tell(scored, noise) local dotted = dot_mv(self.covars, g_delta) for i, v in ipairs(self.mean) do - self.mean[i] = v + self.sigma * dotted[i] + self.mean[i] = v + self.mean_adapt * self.sigma * dotted[i] end --[[