From dc8969469d945c5e5bc181b65259ad76ae18b266 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Thu, 28 Jun 2018 11:03:28 +0200
Subject: [PATCH] move param/sigma decay into es methods

---
 ars.lua  |  8 ++++++++
 main.lua | 33 ++-------------------------------
 snes.lua | 13 +++++++++++++
 xnes.lua | 11 +++++++++++
 4 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/ars.lua b/ars.lua
index 6828b35..52e5bd1 100644
--- a/ars.lua
+++ b/ars.lua
@@ -96,6 +96,14 @@ function Ars:params(new_params)
     return self._params
 end
 
+function Ars:decay(param_decay, sigma_decay)
+    if param_decay > 0 then
+        for i, v in ipairs(self._params) do
+            self._params[i] = v * (1 - self.param_rate * param_decay * self.sigma)
+        end
+    end
+end
+
 function Ars:ask(graycode)
     local asked = {}
     local noise = {}
diff --git a/main.lua b/main.lua
index f5efd94..5b9078d 100644
--- a/main.lua
+++ b/main.lua
@@ -274,39 +274,10 @@ local function learn_from_epoch()
     print("step mean:", step_mean)
     print("step stddev:", step_dev)
 
+    es:decay(cfg.param_decay, cfg.sigma_decay)
+
     base_params = es:params()
 
-    -- TODO: move this all to es:decay methods.
-    if cfg.es == 'snes' then
-        if cfg.sigma_decay > 0 then
-            for i, v in ipairs(es.std) do
-                es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
-            end
-        end
-        if cfg.param_decay > 0 then
-            for i, v in ipairs(base_params) do
-                base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i])
-            end
-        end
-    elseif cfg.es == 'xnes' then
-        if cfg.sigma_decay > 0 then
-            es.sigma = es.sigma * (1 - cfg.sigma_decay)
-        end
-        if cfg.param_decay > 0 then
-            for i, v in ipairs(base_params) do
-                base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.sigma)
-            end
-        end
-    else
-        if cfg.param_decay > 0 then
-            for i, v in ipairs(base_params) do
-                base_params[i] = v * (1 - cfg.param_rate * cfg.param_decay * cfg.deviation)
-            end
-        end
-    end
-
-    es:params(base_params)
-
     local trial_mean, trial_std = calc_mean_dev(trial_rewards)
     local delta_mean, delta_std = calc_mean_dev(delta_rewards)
     local param_mean, param_std = calc_mean_dev(base_params)
diff --git a/snes.lua b/snes.lua
index 7110b62..f2532d5 100644
--- a/snes.lua
+++ b/snes.lua
@@ -68,6 +68,19 @@ function Snes:params(new_mean)
     return self.mean
 end
 
+function Snes:decay(param_decay, sigma_decay)
+    if sigma_decay > 0 then
+        for i, v in ipairs(self.std) do
+            self.std[i] = v * (1 - self.sigma_rate * sigma_decay)
+        end
+    end
+    if param_decay > 0 then
+        for i, v in ipairs(self.mean) do
+            self.mean[i] = v * (1 - self.param_rate * param_decay * self.std[i])
+        end
+    end
+end
+
 function Snes:ask_once(asked, noise)
     asked = asked or {}
     noise = noise or {}
diff --git a/xnes.lua b/xnes.lua
index 241a98d..1e45a8f 100644
--- a/xnes.lua
+++ b/xnes.lua
@@ -80,6 +80,17 @@ function Xnes:params(new_mean)
     return self.mean
 end
 
+function Xnes:decay(param_decay, sigma_decay)
+    if sigma_decay > 0 then
+        self.sigma = self.sigma * (1 - sigma_decay)
+    end
+    if param_decay > 0 then
+        for i, v in ipairs(self.mean) do
+            self.mean[i] = v * (1 - self.param_rate * param_decay * self.sigma)
+        end
+    end
+end
+
 function Xnes:ask_once(asked, noise)
     asked = asked or zeros(self.dims)
     noise = noise or {}