move param/sigma decay into es methods

This commit is contained in:
Connor Olding 2018-06-28 11:03:28 +02:00
parent 450bd70d99
commit dc8969469d
4 changed files with 34 additions and 31 deletions

View file

@ -96,6 +96,14 @@ function Ars:params(new_params)
return self._params
end
function Ars:decay(param_decay, sigma_decay)
if param_decay > 0 then
for i, v in ipairs(self._params) do
self._params[i] = v * (1 - self.param_rate * param_decay * self.sigma)
end
end
end
function Ars:ask(graycode)
local asked = {}
local noise = {}

View file

@ -274,39 +274,10 @@ local function learn_from_epoch()
print("step mean:", step_mean)
print("step stddev:", step_dev)
es:decay(cfg.param_decay, cfg.sigma_decay)
base_params = es:params()
-- TODO: move this all to es:decay methods.
if cfg.es == 'snes' then
if cfg.sigma_decay > 0 then
for i, v in ipairs(es.std) do
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
end
end
if cfg.param_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i])
end
end
elseif cfg.es == 'xnes' then
if cfg.sigma_decay > 0 then
es.sigma = es.sigma * (1 - cfg.sigma_decay)
end
if cfg.param_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.sigma)
end
end
else
if cfg.param_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.param_rate * cfg.param_decay * cfg.deviation)
end
end
end
es:params(base_params)
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
local param_mean, param_std = calc_mean_dev(base_params)

View file

@ -68,6 +68,19 @@ function Snes:params(new_mean)
return self.mean
end
function Snes:decay(param_decay, sigma_decay)
if sigma_decay > 0 then
for i, v in ipairs(self.std) do
self.std[i] = v * (1 - self.sigma_rate * sigma_decay)
end
end
if param_decay > 0 then
for i, v in ipairs(self.mean) do
self.mean[i] = v * (1 - self.param_rate * param_decay * self.std[i])
end
end
end
function Snes:ask_once(asked, noise)
asked = asked or {}
noise = noise or {}

View file

@ -80,6 +80,17 @@ function Xnes:params(new_mean)
return self.mean
end
function Xnes:decay(param_decay, sigma_decay)
if sigma_decay > 0 then
self.sigma = self.sigma * (1 - sigma_decay)
end
if param_decay > 0 then
for i, v in ipairs(self.mean) do
self.mean[i] = v * (1 - self.param_rate * param_decay * self.sigma)
end
end
end
function Xnes:ask_once(asked, noise)
asked = asked or zeros(self.dims)
noise = noise or {}