add sigma decay to xNES

This commit is contained in:
Connor Olding 2018-06-14 22:25:54 +02:00
parent 422468dd47
commit f512f8ac3a
2 changed files with 10 additions and 1 deletions

View file

@ -52,7 +52,7 @@ local common_cfg = {
learning_rate = 1.0,
mean_adapt = 1.0, -- for SNES, xNES.
weight_decay = 0.0,
sigma_decay = 0.0, -- for SNES.
sigma_decay = 0.0, -- for SNES, xNES.
}
local cfg

View file

@ -287,6 +287,15 @@ local function learn_from_epoch()
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
end
end
elseif cfg.es == 'xnes' then
if cfg.sigma_decay > 0 then
es.sigma = es.sigma * (1 - cfg.sigma_decay)
end
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay)
end
end
else
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do