sigma tweaks
This commit is contained in:
parent
12098ee592
commit
ccce6a2d55
2 changed files with 18 additions and 9 deletions
|
@ -77,6 +77,7 @@ local cfg = {
|
||||||
mean_adapt = 0.5,
|
mean_adapt = 0.5,
|
||||||
deviation = 0.5,
|
deviation = 0.5,
|
||||||
weight_decay = 0.025,
|
weight_decay = 0.025,
|
||||||
|
sigma_decay = 0.001,
|
||||||
}
|
}
|
||||||
|
|
||||||
-- TODO: so, uhh..
|
-- TODO: so, uhh..
|
||||||
|
|
20
main.lua
20
main.lua
|
@ -198,7 +198,8 @@ local function prepare_epoch()
|
||||||
print("sigma:", es.sigma)
|
print("sigma:", es.sigma)
|
||||||
elseif cfg.es == 'snes' then
|
elseif cfg.es == 'snes' then
|
||||||
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
|
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
|
||||||
print("sigma:", sigma_mean, sigma_dev)
|
--print("sigma:", sigma_mean, sigma_dev)
|
||||||
|
print("sigma 95%:", sigma_mean + sigma_dev * 1.64485)
|
||||||
end
|
end
|
||||||
|
|
||||||
local precision
|
local precision
|
||||||
|
@ -283,16 +284,23 @@ local function learn_from_epoch()
|
||||||
|
|
||||||
base_params = es:params()
|
base_params = es:params()
|
||||||
|
|
||||||
|
if cfg.es == 'snes' then
|
||||||
|
if cfg.sigma_decay > 0 then
|
||||||
|
for i, v in ipairs(es.std) do
|
||||||
|
es.std[i] = v * (1 - cfg.sigma_decay)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if cfg.weight_decay > 0 then
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
base_params[i] = v * (1 - cfg.weight_decay * es.std[i])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
else
|
||||||
if cfg.weight_decay > 0 then
|
if cfg.weight_decay > 0 then
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v * (1 - cfg.weight_decay)
|
base_params[i] = v * (1 - cfg.weight_decay)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
if cfg.sigma_decay > 0 and cfg.es == 'snes' then
|
|
||||||
for i, v in ipairs(es.std) do
|
|
||||||
es.std[i] = v * (1 - cfg.sigma_decay)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
es:params(base_params)
|
es:params(base_params)
|
||||||
|
|
Loading…
Add table
Reference in a new issue