add sigma decay to xNES
This commit is contained in:
parent
422468dd47
commit
f512f8ac3a
2 changed files with 10 additions and 1 deletions
|
@ -52,7 +52,7 @@ local common_cfg = {
|
||||||
learning_rate = 1.0,
|
learning_rate = 1.0,
|
||||||
mean_adapt = 1.0, -- for SNES, xNES.
|
mean_adapt = 1.0, -- for SNES, xNES.
|
||||||
weight_decay = 0.0,
|
weight_decay = 0.0,
|
||||||
sigma_decay = 0.0, -- for SNES.
|
sigma_decay = 0.0, -- for SNES, xNES.
|
||||||
}
|
}
|
||||||
|
|
||||||
local cfg
|
local cfg
|
||||||
|
|
9
main.lua
9
main.lua
|
@ -287,6 +287,15 @@ local function learn_from_epoch()
|
||||||
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
|
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
elseif cfg.es == 'xnes' then
|
||||||
|
if cfg.sigma_decay > 0 then
|
||||||
|
es.sigma = es.sigma * (1 - cfg.sigma_decay)
|
||||||
|
end
|
||||||
|
if cfg.weight_decay > 0 then
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay)
|
||||||
|
end
|
||||||
|
end
|
||||||
else
|
else
|
||||||
if cfg.weight_decay > 0 then
|
if cfg.weight_decay > 0 then
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
|
|
Loading…
Add table
Reference in a new issue