add PowerSign momentum to ARS, antithetic by default

This commit is contained in:
Connor Olding 2018-06-21 05:14:45 +02:00
parent dc235f5d18
commit 102eefe98c
4 changed files with 27 additions and 3 deletions

20
ars.lua
View file

@ -24,9 +24,15 @@ local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local normalize_sums = util.normalize_sums
local sign = util.sign
local Ars = Base:extend()
local exp_lut = {}
exp_lut[-1] = exp(-1)
exp_lut[0] = exp(0)
exp_lut[1] = exp(1)
local function collect_best_indices(scored, top, antithetic)
-- select one (the best) reward of each pos/neg pair.
local best_rewards
@ -60,7 +66,8 @@ local function kinda_lipschitz(dir, pos, neg, mid)
return max(l0, l1) / (2 * dev)
end
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
momentum)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
@ -68,13 +75,15 @@ function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
self.antithetic = antithetic == nil and true or antithetic
self.momentum = momentum or 0
self.poptop = poptop or popsize
assert(self.poptop <= popsize)
if self.antithetic then self.popsize = self.popsize * 2 end
self._params = zeros(self.dims)
if self.momentum > 0 then self.accum = zeros(self.dims) end
self.evals = 0
end
@ -189,6 +198,13 @@ function Ars:tell(scored, unperturbed_score)
end
end
if self.momentum > 0 then
for i, v in ipairs(step) do
self.accum[i] = self.momentum * self.accum[i] + v
step[i] = v * exp_lut[sign(v) * sign(self.accum[i])]
end
end
for i, v in ipairs(self._params) do
self._params[i] = v + self.param_rate * step[i]
end

View file

@ -58,6 +58,7 @@ local defaults = {
-- if you don't specify them individually.
param_decay = 0.0,
sigma_decay = 0.0, -- for SNES, xNES.
momentum = 0.0, -- for ARS.
}
local presets = require("presets")

View file

@ -508,7 +508,8 @@ local function init()
end
elseif cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.base_rate, cfg.deviation, cfg.negate_trials)
cfg.base_rate, cfg.deviation, cfg.negate_trials,
cfg.momentum)
else
error("Unknown evolution strategy specified: " + tostring(cfg.es))
end

View file

@ -15,6 +15,11 @@ local select = select
local sort = table.sort
local sqrt = math.sqrt
local function sign(x)
-- remember that 0 is truthy in Lua.
return x == 0 and 0 or x > 0 and 1 or -1
end
local function signbyte(x)
if x >= 128 then x = 256 - x end
return x
@ -245,6 +250,7 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
end
return {
sign=sign,
signbyte=signbyte,
boolean_xor=boolean_xor,
log2=log2,