-- Patch summary: adds an optional sign-based "momentum" term to the ARS
-- optimizer and wires it through config.lua, main.lua and util.lua.
--
--   ars.lua    * imports util.sign and builds exp_lut, a table holding
--                exp(-1), exp(0) and exp(1) under the keys -1, 0 and 1.
--              * Ars:init takes a new trailing `momentum` argument, stored
--                as self.momentum (0 when omitted); self.accum is allocated
--                with zeros(self.dims) only when momentum > 0.
--              * Ars:init now treats an omitted (nil) `antithetic` as true;
--                the removed expression turned nil into false. A non-nil
--                value is stored as given.
--              * Ars:tell, when momentum > 0: for each step element v,
--                accum[i] = momentum * accum[i] + v is updated first, then
--                step[i] = v * exp_lut[sign(v) * sign(accum[i])], i.e. the
--                factor is exp(1) when the signs agree, exp(-1) when they
--                differ and exp(0) when either sign is 0.
--   config.lua * adds `momentum = 0.0` to the defaults table.
--   main.lua   * passes cfg.momentum as the seventh argument to ars.Ars.
--   util.lua   * adds sign(x), returning 0, 1 or -1, and exports it.
--
-- NOTE(review): the hunks below have lost their original line breaks and
--   indentation, so this text cannot be applied by patch tools as-is.
-- NOTE(review): main.lua's unchanged context line joins a string literal and
--   tostring(cfg.es) with `+`; Lua concatenates with `..`, so reaching that
--   branch would raise an arithmetic error instead of the intended message.
--   Worth fixing in a follow-up change against main.lua itself.
-- NOTE(review): Ars:init computes a default for self.popsize, but
--   `self.poptop = poptop or popsize` and the assert after it read the raw
--   `popsize` argument, which may be nil -- confirm whether self.popsize was
--   intended (old line 67 of ars.lua is not visible in these hunks).
-- NOTE(review): callers that relied on `antithetic` defaulting to false will
--   now get a doubled popsize -- confirm that cfg.negate_trials is always set.
-- NOTE(review): exp_lut calls `exp` at file scope; presumably a local alias
--   of math.exp defined above the first hunk -- not visible here, confirm.
diff --git a/ars.lua b/ars.lua index ef4df1a..28b6176 100644 --- a/ars.lua +++ b/ars.lua @@ -24,9 +24,15 @@ local util = require "util" local argsort = util.argsort local calc_mean_dev = util.calc_mean_dev local normalize_sums = util.normalize_sums +local sign = util.sign local Ars = Base:extend() +local exp_lut = {} +exp_lut[-1] = exp(-1) +exp_lut[0] = exp(0) +exp_lut[1] = exp(1) + local function collect_best_indices(scored, top, antithetic) -- select one (the best) reward of each pos/neg pair. local best_rewards @@ -60,7 +66,8 @@ local function kinda_lipschitz(dir, pos, neg, mid) return max(l0, l1) / (2 * dev) end -function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic) +function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic, + momentum) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) @@ -68,13 +75,15 @@ function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic) self.sigma_rate = base_rate self.covar_rate = base_rate self.sigma = sigma or 1 - self.antithetic = antithetic and true or false + self.antithetic = antithetic == nil and true or antithetic + self.momentum = momentum or 0 self.poptop = poptop or popsize assert(self.poptop <= popsize) if self.antithetic then self.popsize = self.popsize * 2 end self._params = zeros(self.dims) + if self.momentum > 0 then self.accum = zeros(self.dims) end self.evals = 0 end @@ -189,6 +198,13 @@ function Ars:tell(scored, unperturbed_score) end end + if self.momentum > 0 then + for i, v in ipairs(step) do + self.accum[i] = self.momentum * self.accum[i] + v + step[i] = v * exp_lut[sign(v) * sign(self.accum[i])] + end + end + for i, v in ipairs(self._params) do self._params[i] = v + self.param_rate * step[i] end diff --git a/config.lua b/config.lua index 562f686..63746f6 100644 --- a/config.lua +++ b/config.lua @@ -58,6 +58,7 @@ local defaults = { -- if you don't specify them individually. 
param_decay = 0.0, sigma_decay = 0.0, -- for SNES, xNES. + momentum = 0.0, -- for ARS. } local presets = require("presets") diff --git a/main.lua b/main.lua index 7a62b45..5eb3d24 100644 --- a/main.lua +++ b/main.lua @@ -508,7 +508,8 @@ local function init() end elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, - cfg.base_rate, cfg.deviation, cfg.negate_trials) + cfg.base_rate, cfg.deviation, cfg.negate_trials, + cfg.momentum) else error("Unknown evolution strategy specified: " + tostring(cfg.es)) end diff --git a/util.lua b/util.lua index 6ab8daf..1b98428 100644 --- a/util.lua +++ b/util.lua @@ -15,6 +15,11 @@ local select = select local sort = table.sort local sqrt = math.sqrt +local function sign(x) + -- remember that 0 is truthy in Lua. + return x == 0 and 0 or x > 0 and 1 or -1 +end + local function signbyte(x) if x >= 128 then x = 256 - x end return x @@ -245,6 +250,7 @@ local function weighted_mann_whitney(s0, s1, w0, w1) end return { + sign=sign, signbyte=signbyte, boolean_xor=boolean_xor, log2=log2,