From f3fc95404cd61002d7f93962e656db6baa838098 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sat, 16 Jun 2018 00:24:55 +0200 Subject: [PATCH] overhaul learning rates: - rename mean_adapt to weight_rate - sigma and covar update rates can be specified separately (sigma_rate, covar_rate) - base decays on current rates instead of initially configured rates (this might break stuff) - base_rate takes the place of learning_rate --- ars.lua | 9 ++++++--- config.lua | 37 +++++++++++++++++++++++++------------ main.lua | 21 ++++++++++++--------- snes.lua | 25 +++++++++++++------------ xnes.lua | 16 ++++++++-------- 5 files changed, 64 insertions(+), 44 deletions(-) diff --git a/ars.lua b/ars.lua index 2cc6422..03ac486 100644 --- a/ars.lua +++ b/ars.lua @@ -60,10 +60,13 @@ local function kinda_lipschitz(dir, pos, neg, mid) return max(l0, l1) / (2 * dev) end -function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic) +function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) - self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.weight_rate = base_rate + self.sigma_rate = base_rate + self.covar_rate = base_rate self.sigma = sigma or 1 self.antithetic = antithetic and true or false @@ -187,7 +190,7 @@ function Ars:tell(scored, unperturbed_score) end for i, v in ipairs(self._params) do - self._params[i] = v + self.learning_rate * step[i] + self._params[i] = v + self.weight_rate * step[i] end self.noise = nil diff --git a/config.lua b/config.lua index 2b29b38..1420019 100644 --- a/config.lua +++ b/config.lua @@ -50,8 +50,8 @@ local common_cfg = { min_refresh = 0.1, -- for SNES. -- epoch-related rates: - learning_rate = 1.0, - mean_adapt = 1.0, -- for SNES, xNES. + base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this + -- if you don't specify them individually. weight_decay = 0.0, sigma_decay = 0.0, -- for SNES, xNES. } @@ -75,10 +75,10 @@ if preset == 'snes' then deviation = 1.0, min_refresh = 0.2, - learning_rate = 0.1, -- TODO: rename to learn_primary or something. - mean_adapt = 0.5, -- TODO: rename to learn_secondary or something. - weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt. - sigma_decay = 0.01, -- note: multiplied by learning_rate. + weight_rate = 0.5, + sigma_rate = 0.1, + weight_decay = 0.02, -- note: multiplied by its std, and weight_rate. + sigma_decay = 0.01, -- note: multiplied by sigma_rate. } elseif preset == 'snes2' then @@ -108,8 +108,8 @@ elseif preset == 'snes2' then epoch_trials = 100, - learning_rate = 0.01, - mean_adapt = 1.0, + weight_rate = 1.0, + sigma_rate = 0.01, weight_decay = 0.02, sigma_decay = 0.01, } @@ -140,7 +140,9 @@ elseif preset == 'xnes' then epoch_trials = 50, - learning_rate = 0.01, + weight_rate = 1.0, + sigma_rate = 0.01, + covar_rate = 0.01, } elseif preset == 'xnes2' then @@ -169,8 +171,9 @@ elseif preset == 'xnes2' then epoch_trials = 10, --50, - learning_rate = 0.04, - mean_adapt = 0.5, + weight_rate = 0.5, + sigma_rate = 0.04, + covar_rate = 0.04, weight_decay = 0.004, sigma_decay = 0.00128, } @@ -199,7 +202,7 @@ elseif preset == 'ars' then epoch_trials = 25 * 2, - learning_rate = 1.0, + weight_rate = 1.0, weight_decay = 0.0025, } @@ -218,6 +221,16 @@ end -- TODO: so, uhh.. -- what happens when playback_mode is true but unperturbed_trial is false? +setmetatable(cfg, {__index=common_cfg}) -- gets overridden later. + +if cfg.es == 'ars' then + if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end +else + if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end +end +if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end +if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end + setmetatable(cfg, { __index = function(t, n) if common_cfg[n] ~= nil then return common_cfg[n] end diff --git a/main.lua b/main.lua index 068d0bc..c554d06 100644 --- a/main.lua +++ b/main.lua @@ -276,15 +276,16 @@ local function learn_from_epoch() base_params = es:params() + -- TODO: move this all to es:decay methods. if cfg.es == 'snes' then if cfg.sigma_decay > 0 then for i, v in ipairs(es.std) do - es.std[i] = v * (1 - cfg.learning_rate * cfg.sigma_decay) + es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay) end end if cfg.weight_decay > 0 then for i, v in ipairs(base_params) do - base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i]) + base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i]) end end elseif cfg.es == 'xnes' then @@ -293,7 +294,7 @@ local function learn_from_epoch() end if cfg.weight_decay > 0 then for i, v in ipairs(base_params) do - base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay) + base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay) end end else @@ -494,14 +495,11 @@ local function init() -- if you get an out of memory error, you can't use xNES. sorry! -- maybe there'll be a patch for FCEUX in the future. es = xnes.Xnes(network.n_param, cfg.epoch_trials, - cfg.learning_rate, cfg.deviation, cfg.negate_trials) - -- TODO: clean this up into an interface: - es.mean_adapt = cfg.mean_adapt + cfg.base_rate, cfg.deviation, cfg.negate_trials) elseif cfg.es == 'snes' then es = snes.Snes(network.n_param, cfg.epoch_trials, - cfg.learning_rate, cfg.deviation, cfg.negate_trials) + cfg.base_rate, cfg.deviation, cfg.negate_trials) -- TODO: clean this up into an interface: - es.mean_adapt = cfg.mean_adapt es.min_refresh = cfg.min_refresh if exists(std_fn) then @@ -513,11 +511,16 @@ local function init() end elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, - cfg.learning_rate, cfg.deviation, cfg.negate_trials) + cfg.base_rate, cfg.deviation, cfg.negate_trials) else error("Unknown evolution strategy specified: " + tostring(cfg.es)) end + es.weight_rate = cfg.weight_rate + es.sigma_rate = cfg.sigma_rate + es.covar_rate = cfg.covar_rate + es.rate_init = cfg.sigma_rate -- just for SNES? + es:params(network:collect()) end diff --git a/snes.lua b/snes.lua index b1e5f32..164f6f6 100644 --- a/snes.lua +++ b/snes.lua @@ -32,24 +32,25 @@ local weighted_mann_whitney = util.weighted_mann_whitney local Snes = Base:extend() -function Snes:init(dims, popsize, learning_rate, sigma, antithetic) +function Snes:init(dims, popsize, base_rate, sigma, antithetic) -- heuristic borrowed from CMA-ES: self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) - self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.weight_rate = 1.0 + self.sigma_rate = base_rate + self.covar_rate = base_rate self.sigma = sigma or 1 self.antithetic = antithetic and true or false if self.antithetic then self.popsize = self.popsize * 2 end - self.rate_init = self.learning_rate + self.rate_init = self.sigma_rate self.mean = zeros{dims} self.std = zeros{dims} for i=1, self.dims do self.std[i] = self.sigma end - self.mean_adapt = 1.0 - self.old_asked = {} self.old_noise = {} self.old_score = {} @@ -250,15 +251,15 @@ function Snes:tell(scored) end for i, v in ipairs(self.mean) do - self.mean[i] = v + self.mean_adapt * step[i] + self.mean[i] = v + self.weight_rate * step[i] end local otherwise = {} self.std_old = {} for i, v in ipairs(self.std) do self.std_old[i] = v - self.std[i] = v * exp(self.learning_rate * 0.5 * g_std[i]) - otherwise[i] = v * exp(self.learning_rate * 0.75 * g_std[i]) + self.std[i] = v * exp(self.sigma_rate * 0.5 * g_std[i]) + otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i]) end self:adapt(asked, otherwise, utility) @@ -283,11 +284,11 @@ function Snes:adapt(asked, otherwise, qualities) --print("p:", p) if p < 0.5 - 1 / (3 * (self.dims + 1)) then - self.learning_rate = 0.9 * self.learning_rate + 0.1 * self.rate_init - print("learning rate -:", self.learning_rate) + self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init + print("learning rate -:", self.sigma_rate) else - self.learning_rate = min(1.1 * self.learning_rate, 1) - print("learning rate +:", self.learning_rate) + self.sigma_rate = min(1.1 * self.sigma_rate, 1) + print("learning rate +:", self.sigma_rate) end end diff --git a/xnes.lua b/xnes.lua index b48a654..2be4c14 100644 --- a/xnes.lua +++ b/xnes.lua @@ -51,11 +51,14 @@ local function make_covars(dims, sigma, out) return covars end -function Xnes:init(dims, popsize, learning_rate, sigma, antithetic) +function Xnes:init(dims, popsize, base_rate, sigma, antithetic) -- heuristic borrowed from CMA-ES: self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) - self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.weight_rate = 1.0 + self.sigma_rate = base_rate + self.covar_rate = base_rate self.sigma = sigma or 1 self.antithetic = antithetic and true or false @@ -67,8 +70,6 @@ function Xnes:init(dims, popsize, learning_rate, sigma, antithetic) -- note: this is technically the co-standard-deviation. -- you can imagine the "s" standing for "sqrt" if you like. self.covars = make_covars(self.dims, self.sigma, self.covars) - - self.mean_adapt = 1.0 end function Xnes:params(new_mean) @@ -182,13 +183,12 @@ function Xnes:tell(scored, noise) end for i, v in ipairs(self.mean) do - self.mean[i] = v + self.mean_adapt * step[i] + self.mean[i] = v + self.weight_rate * step[i] end - local lr = self.learning_rate * 0.5 - self.sigma = self.sigma * exp(lr * g_sigma) + self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma) for i, v in ipairs(self.covars) do - self.covars[i] = v * exp(lr * g_covars[i]) + self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i]) end -- bookkeeping: