add PowerSign momentum to ARS, antithetic by default
This commit is contained in:
parent
dc235f5d18
commit
102eefe98c
4 changed files with 27 additions and 3 deletions
20
ars.lua
20
ars.lua
|
@ -24,9 +24,15 @@ local util = require "util"
|
|||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local normalize_sums = util.normalize_sums
|
||||
local sign = util.sign
|
||||
|
||||
local Ars = Base:extend()
|
||||
|
||||
-- Precomputed exp(k) for k = -1, 0 and 1.
-- The momentum step in Ars:tell indexes this table with
-- sign(v) * sign(self.accum[i]), which can only be one of those three values.
-- NOTE(review): `exp` is presumably math.exp localized earlier in the file -- confirm.
local exp_lut = {}
|
||||
exp_lut[-1] = exp(-1)
|
||||
exp_lut[0] = exp(0)
|
||||
exp_lut[1] = exp(1)
|
||||
|
||||
local function collect_best_indices(scored, top, antithetic)
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
local best_rewards
|
||||
|
@ -60,7 +66,8 @@ local function kinda_lipschitz(dir, pos, neg, mid)
|
|||
return max(l0, l1) / (2 * dev)
|
||||
end
|
||||
|
||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
|
||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
|
||||
momentum)
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
|
@ -68,13 +75,15 @@ function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
|
|||
self.sigma_rate = base_rate
|
||||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
self.antithetic = antithetic == nil and true or antithetic
|
||||
self.momentum = momentum or 0
|
||||
|
||||
self.poptop = poptop or popsize
|
||||
assert(self.poptop <= popsize)
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self._params = zeros(self.dims)
|
||||
if self.momentum > 0 then self.accum = zeros(self.dims) end
|
||||
|
||||
self.evals = 0
|
||||
end
|
||||
|
@ -189,6 +198,13 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
end
|
||||
|
||||
if self.momentum > 0 then
|
||||
for i, v in ipairs(step) do
|
||||
self.accum[i] = self.momentum * self.accum[i] + v
|
||||
step[i] = v * exp_lut[sign(v) * sign(self.accum[i])]
|
||||
end
|
||||
end
|
||||
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
|
|
|
@ -58,6 +58,7 @@ local defaults = {
|
|||
-- if you don't specify them individually.
|
||||
param_decay = 0.0,
|
||||
sigma_decay = 0.0, -- for SNES, xNES.
|
||||
momentum = 0.0, -- for ARS.
|
||||
}
|
||||
|
||||
local presets = require("presets")
|
||||
|
|
3
main.lua
3
main.lua
|
@ -508,7 +508,8 @@ local function init()
|
|||
end
|
||||
elseif cfg.es == 'ars' then
|
||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials)
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials,
|
||||
cfg.momentum)
|
||||
else
|
||||
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
||||
end
|
||||
|
|
6
util.lua
6
util.lua
|
@ -15,6 +15,11 @@ local select = select
|
|||
local sort = table.sort
|
||||
local sqrt = math.sqrt
|
||||
|
||||
-- Returns the sign of x as a number:
--   0 when x == 0, 1 when x > 0, and -1 otherwise.
-- "Otherwise" covers negative numbers and also NaN, because both
-- comparisons below evaluate to false for NaN.
local function sign(x)
|
||||
-- remember that 0 is truthy in Lua, so `x == 0 and 0` really yields 0
-- instead of falling through to the `or` branch.
|
||||
return x == 0 and 0 or x > 0 and 1 or -1
|
||||
end
|
||||
|
||||
local function signbyte(x)
|
||||
if x >= 128 then x = 256 - x end
|
||||
return x
|
||||
|
@ -245,6 +250,7 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
|
|||
end
|
||||
|
||||
return {
|
||||
sign=sign,
|
||||
signbyte=signbyte,
|
||||
boolean_xor=boolean_xor,
|
||||
log2=log2,
|
||||
|
|
Loading…
Reference in a new issue