move param/sigma decay into es methods
This commit is contained in:
parent
450bd70d99
commit
dc8969469d
4 changed files with 34 additions and 31 deletions
8
ars.lua
8
ars.lua
|
@ -96,6 +96,14 @@ function Ars:params(new_params)
|
||||||
return self._params
|
return self._params
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Ars:decay(param_decay, sigma_decay)
|
||||||
|
if param_decay > 0 then
|
||||||
|
for i, v in ipairs(self._params) do
|
||||||
|
self._params[i] = v * (1 - self.param_rate * param_decay * self.sigma)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function Ars:ask(graycode)
|
function Ars:ask(graycode)
|
||||||
local asked = {}
|
local asked = {}
|
||||||
local noise = {}
|
local noise = {}
|
||||||
|
|
33
main.lua
33
main.lua
|
@ -274,39 +274,10 @@ local function learn_from_epoch()
|
||||||
print("step mean:", step_mean)
|
print("step mean:", step_mean)
|
||||||
print("step stddev:", step_dev)
|
print("step stddev:", step_dev)
|
||||||
|
|
||||||
|
es:decay(cfg.param_decay, cfg.sigma_decay)
|
||||||
|
|
||||||
base_params = es:params()
|
base_params = es:params()
|
||||||
|
|
||||||
-- TODO: move this all to es:decay methods.
|
|
||||||
if cfg.es == 'snes' then
|
|
||||||
if cfg.sigma_decay > 0 then
|
|
||||||
for i, v in ipairs(es.std) do
|
|
||||||
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
if cfg.param_decay > 0 then
|
|
||||||
for i, v in ipairs(base_params) do
|
|
||||||
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i])
|
|
||||||
end
|
|
||||||
end
|
|
||||||
elseif cfg.es == 'xnes' then
|
|
||||||
if cfg.sigma_decay > 0 then
|
|
||||||
es.sigma = es.sigma * (1 - cfg.sigma_decay)
|
|
||||||
end
|
|
||||||
if cfg.param_decay > 0 then
|
|
||||||
for i, v in ipairs(base_params) do
|
|
||||||
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.sigma)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
else
|
|
||||||
if cfg.param_decay > 0 then
|
|
||||||
for i, v in ipairs(base_params) do
|
|
||||||
base_params[i] = v * (1 - cfg.param_rate * cfg.param_decay * cfg.deviation)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
es:params(base_params)
|
|
||||||
|
|
||||||
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
||||||
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
||||||
local param_mean, param_std = calc_mean_dev(base_params)
|
local param_mean, param_std = calc_mean_dev(base_params)
|
||||||
|
|
13
snes.lua
13
snes.lua
|
@ -68,6 +68,19 @@ function Snes:params(new_mean)
|
||||||
return self.mean
|
return self.mean
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Snes:decay(param_decay, sigma_decay)
|
||||||
|
if sigma_decay > 0 then
|
||||||
|
for i, v in ipairs(self.std) do
|
||||||
|
self.std[i] = v * (1 - self.sigma_rate * sigma_decay)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if param_decay > 0 then
|
||||||
|
for i, v in ipairs(self.mean) do
|
||||||
|
self.mean[i] = v * (1 - self.param_rate * param_decay * self.std[i])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function Snes:ask_once(asked, noise)
|
function Snes:ask_once(asked, noise)
|
||||||
asked = asked or {}
|
asked = asked or {}
|
||||||
noise = noise or {}
|
noise = noise or {}
|
||||||
|
|
11
xnes.lua
11
xnes.lua
|
@ -80,6 +80,17 @@ function Xnes:params(new_mean)
|
||||||
return self.mean
|
return self.mean
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Xnes:decay(param_decay, sigma_decay)
|
||||||
|
if sigma_decay > 0 then
|
||||||
|
self.sigma = self.sigma * (1 - sigma_decay)
|
||||||
|
end
|
||||||
|
if param_decay > 0 then
|
||||||
|
for i, v in ipairs(self.mean) do
|
||||||
|
self.mean[i] = v * (1 - self.param_rate * param_decay * self.sigma)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function Xnes:ask_once(asked, noise)
|
function Xnes:ask_once(asked, noise)
|
||||||
asked = asked or zeros(self.dims)
|
asked = asked or zeros(self.dims)
|
||||||
noise = noise or {}
|
noise = noise or {}
|
||||||
|
|
Loading…
Add table
Reference in a new issue