add sigma decay (applied for SNES); move sigma printing to start of epoch

This commit is contained in:
Connor Olding 2018-06-10 19:34:17 +02:00
parent 56f7c01256
commit fa0287d966
2 changed files with 15 additions and 7 deletions

View File

@@ -34,6 +34,8 @@ local common_cfg = {
learning_rate = 1.0, learning_rate = 1.0,
mean_adapt = 1.0, -- for xNES mean_adapt = 1.0, -- for xNES
weight_decay = 0.0,
sigma_decay = 0.0,
es = 'ars', es = 'ars',
ars_lips = false, ars_lips = false,

View File

@@ -191,6 +191,13 @@ local function prepare_epoch()
print('preparing epoch '..tostring(epoch_i)..'...') print('preparing epoch '..tostring(epoch_i)..'...')
empty(trial_rewards) empty(trial_rewards)
if cfg.es == 'xnes' then
print("sigma:", es.sigma)
elseif cfg.es == 'snes' then
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
print("sigma:", sigma_mean, sigma_dev)
end
local precision local precision
if cfg.graycode then if cfg.graycode then
precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392 precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
@@ -254,13 +261,6 @@ local function learn_from_epoch()
es:tell(trial_rewards) es:tell(trial_rewards)
end end
if cfg.es == 'xnes' then
print("sigma:", es.sigma)
elseif cfg.es == 'snes' then
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
print("sigma:", sigma_mean, sigma_dev)
end
local step_mean, step_dev = 0, 0 local step_mean, step_dev = 0, 0
--[[ TODO --[[ TODO
local step_mean, step_dev = calc_mean_dev(step) local step_mean, step_dev = calc_mean_dev(step)
@@ -286,6 +286,12 @@ local function learn_from_epoch()
end end
end end
if cfg.sigma_decay > 0 and cfg.es == 'snes' then
for i, v in ipairs(es.std) do
es.std[i] = v * (1 - cfg.sigma_decay)
end
end
es:params(base_params) es:params(base_params)
local trial_mean, trial_std = calc_mean_dev(trial_rewards) local trial_mean, trial_std = calc_mean_dev(trial_rewards)