add sigma decay; move printing to start of epoch
This commit is contained in:
parent
56f7c01256
commit
fa0287d966
2 changed files with 15 additions and 7 deletions
|
@ -34,6 +34,8 @@ local common_cfg = {
|
||||||
|
|
||||||
learning_rate = 1.0,
|
learning_rate = 1.0,
|
||||||
mean_adapt = 1.0, -- for xNES
|
mean_adapt = 1.0, -- for xNES
|
||||||
|
weight_decay = 0.0,
|
||||||
|
sigma_decay = 0.0,
|
||||||
|
|
||||||
es = 'ars',
|
es = 'ars',
|
||||||
ars_lips = false,
|
ars_lips = false,
|
||||||
|
|
20
main.lua
20
main.lua
|
@ -191,6 +191,13 @@ local function prepare_epoch()
|
||||||
print('preparing epoch '..tostring(epoch_i)..'...')
|
print('preparing epoch '..tostring(epoch_i)..'...')
|
||||||
empty(trial_rewards)
|
empty(trial_rewards)
|
||||||
|
|
||||||
|
if cfg.es == 'xnes' then
|
||||||
|
print("sigma:", es.sigma)
|
||||||
|
elseif cfg.es == 'snes' then
|
||||||
|
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
|
||||||
|
print("sigma:", sigma_mean, sigma_dev)
|
||||||
|
end
|
||||||
|
|
||||||
local precision
|
local precision
|
||||||
if cfg.graycode then
|
if cfg.graycode then
|
||||||
precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
|
precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
|
||||||
|
@ -254,13 +261,6 @@ local function learn_from_epoch()
|
||||||
es:tell(trial_rewards)
|
es:tell(trial_rewards)
|
||||||
end
|
end
|
||||||
|
|
||||||
if cfg.es == 'xnes' then
|
|
||||||
print("sigma:", es.sigma)
|
|
||||||
elseif cfg.es == 'snes' then
|
|
||||||
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
|
|
||||||
print("sigma:", sigma_mean, sigma_dev)
|
|
||||||
end
|
|
||||||
|
|
||||||
local step_mean, step_dev = 0, 0
|
local step_mean, step_dev = 0, 0
|
||||||
--[[ TODO
|
--[[ TODO
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
|
@ -286,6 +286,12 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if cfg.sigma_decay > 0 and cfg.es == 'snes' then
|
||||||
|
for i, v in ipairs(es.std) do
|
||||||
|
es.std[i] = v * (1 - cfg.sigma_decay)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
es:params(base_params)
|
es:params(base_params)
|
||||||
|
|
||||||
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
||||||
|
|
Loading…
Add table
Reference in a new issue