diff --git a/config.lua b/config.lua index 89e8586..179e830 100644 --- a/config.lua +++ b/config.lua @@ -77,6 +77,7 @@ local cfg = { mean_adapt = 0.5, deviation = 0.5, weight_decay = 0.025, + sigma_decay = 0.001, } -- TODO: so, uhh.. diff --git a/main.lua b/main.lua index f00eaed..f7a929f 100644 --- a/main.lua +++ b/main.lua @@ -198,7 +198,8 @@ local function prepare_epoch() print("sigma:", es.sigma) elseif cfg.es == 'snes' then local sigma_mean, sigma_dev = calc_mean_dev(es.std) - print("sigma:", sigma_mean, sigma_dev) + --print("sigma:", sigma_mean, sigma_dev) + print("sigma 95%:", sigma_mean + sigma_dev * 1.64485) end local precision @@ -283,15 +284,22 @@ local function learn_from_epoch() base_params = es:params() - if cfg.weight_decay > 0 then - for i, v in ipairs(base_params) do - base_params[i] = v * (1 - cfg.weight_decay) + if cfg.es == 'snes' then + if cfg.sigma_decay > 0 then + for i, v in ipairs(es.std) do + es.std[i] = v * (1 - cfg.sigma_decay) + end end - end - - if cfg.sigma_decay > 0 and cfg.es == 'snes' then - for i, v in ipairs(es.std) do - es.std[i] = v * (1 - cfg.sigma_decay) + if cfg.weight_decay > 0 then + for i, v in ipairs(base_params) do + base_params[i] = v * (1 - cfg.weight_decay * es.std[i]) + end + end + else + if cfg.weight_decay > 0 then + for i, v in ipairs(base_params) do + base_params[i] = v * (1 - cfg.weight_decay) + end end end