sigma tweaks

This commit is contained in:
Connor Olding 2018-06-12 05:39:22 +02:00
parent 12098ee592
commit ccce6a2d55
2 changed files with 18 additions and 9 deletions

View file

@ -77,6 +77,7 @@ local cfg = {
mean_adapt = 0.5,
deviation = 0.5,
weight_decay = 0.025,
sigma_decay = 0.001,
}
-- TODO: so, uhh..

View file

@ -198,7 +198,8 @@ local function prepare_epoch()
print("sigma:", es.sigma)
elseif cfg.es == 'snes' then
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
print("sigma:", sigma_mean, sigma_dev)
--print("sigma:", sigma_mean, sigma_dev)
print("sigma 95%:", sigma_mean + sigma_dev * 1.64485)
end
local precision
@ -283,15 +284,22 @@ local function learn_from_epoch()
base_params = es:params()
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.weight_decay)
if cfg.es == 'snes' then
if cfg.sigma_decay > 0 then
for i, v in ipairs(es.std) do
es.std[i] = v * (1 - cfg.sigma_decay)
end
end
end
if cfg.sigma_decay > 0 and cfg.es == 'snes' then
for i, v in ipairs(es.std) do
es.std[i] = v * (1 - cfg.sigma_decay)
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.weight_decay * es.std[i])
end
end
else
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.weight_decay)
end
end
end