diff --git a/config.lua b/config.lua index 96c56ff..89e8586 100644 --- a/config.lua +++ b/config.lua @@ -34,6 +34,8 @@ local common_cfg = { learning_rate = 1.0, mean_adapt = 1.0, -- for xNES + weight_decay = 0.0, + sigma_decay = 0.0, es = 'ars', ars_lips = false, diff --git a/main.lua b/main.lua index 21d3467..8db1d99 100644 --- a/main.lua +++ b/main.lua @@ -191,6 +191,13 @@ local function prepare_epoch() print('preparing epoch '..tostring(epoch_i)..'...') empty(trial_rewards) + if cfg.es == 'xnes' then + print("sigma:", es.sigma) + elseif cfg.es == 'snes' then + local sigma_mean, sigma_dev = calc_mean_dev(es.std) + print("sigma:", sigma_mean, sigma_dev) + end + local precision if cfg.graycode then precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392 @@ -254,13 +261,6 @@ local function learn_from_epoch() es:tell(trial_rewards) end - if cfg.es == 'xnes' then - print("sigma:", es.sigma) - elseif cfg.es == 'snes' then - local sigma_mean, sigma_dev = calc_mean_dev(es.std) - print("sigma:", sigma_mean, sigma_dev) - end - local step_mean, step_dev = 0, 0 --[[ TODO local step_mean, step_dev = calc_mean_dev(step) @@ -286,6 +286,12 @@ local function learn_from_epoch() end end + if cfg.sigma_decay > 0 and cfg.es == 'snes' then + for i, v in ipairs(es.std) do + es.std[i] = v * (1 - cfg.sigma_decay) + end + end + es:params(base_params) local trial_mean, trial_std = calc_mean_dev(trial_rewards)