move param/sigma decay into es methods
This commit is contained in:
parent
450bd70d99
commit
dc8969469d
4 changed files with 34 additions and 31 deletions
8
ars.lua
8
ars.lua
|
@ -96,6 +96,14 @@ function Ars:params(new_params)
|
|||
return self._params
|
||||
end
|
||||
|
||||
function Ars:decay(param_decay, sigma_decay)
|
||||
if param_decay > 0 then
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v * (1 - self.param_rate * param_decay * self.sigma)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function Ars:ask(graycode)
|
||||
local asked = {}
|
||||
local noise = {}
|
||||
|
|
33
main.lua
33
main.lua
|
@ -274,39 +274,10 @@ local function learn_from_epoch()
|
|||
print("step mean:", step_mean)
|
||||
print("step stddev:", step_dev)
|
||||
|
||||
es:decay(cfg.param_decay, cfg.sigma_decay)
|
||||
|
||||
base_params = es:params()
|
||||
|
||||
-- TODO: move this all to es:decay methods.
|
||||
if cfg.es == 'snes' then
|
||||
if cfg.sigma_decay > 0 then
|
||||
for i, v in ipairs(es.std) do
|
||||
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
|
||||
end
|
||||
end
|
||||
if cfg.param_decay > 0 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i])
|
||||
end
|
||||
end
|
||||
elseif cfg.es == 'xnes' then
|
||||
if cfg.sigma_decay > 0 then
|
||||
es.sigma = es.sigma * (1 - cfg.sigma_decay)
|
||||
end
|
||||
if cfg.param_decay > 0 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.sigma)
|
||||
end
|
||||
end
|
||||
else
|
||||
if cfg.param_decay > 0 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v * (1 - cfg.param_rate * cfg.param_decay * cfg.deviation)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
es:params(base_params)
|
||||
|
||||
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
||||
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
||||
local param_mean, param_std = calc_mean_dev(base_params)
|
||||
|
|
13
snes.lua
13
snes.lua
|
@ -68,6 +68,19 @@ function Snes:params(new_mean)
|
|||
return self.mean
|
||||
end
|
||||
|
||||
function Snes:decay(param_decay, sigma_decay)
|
||||
if sigma_decay > 0 then
|
||||
for i, v in ipairs(self.std) do
|
||||
self.std[i] = v * (1 - self.sigma_rate * sigma_decay)
|
||||
end
|
||||
end
|
||||
if param_decay > 0 then
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v * (1 - self.param_rate * param_decay * self.std[i])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function Snes:ask_once(asked, noise)
|
||||
asked = asked or {}
|
||||
noise = noise or {}
|
||||
|
|
11
xnes.lua
11
xnes.lua
|
@ -80,6 +80,17 @@ function Xnes:params(new_mean)
|
|||
return self.mean
|
||||
end
|
||||
|
||||
function Xnes:decay(param_decay, sigma_decay)
|
||||
if sigma_decay > 0 then
|
||||
self.sigma = self.sigma * (1 - sigma_decay)
|
||||
end
|
||||
if param_decay > 0 then
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v * (1 - self.param_rate * param_decay * self.sigma)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function Xnes:ask_once(asked, noise)
|
||||
asked = asked or zeros(self.dims)
|
||||
noise = noise or {}
|
||||
|
|
Loading…
Reference in a new issue