From f512f8ac3ad46077868a811b29569cbafd2df382 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 14 Jun 2018 22:25:54 +0200 Subject: [PATCH] add sigma decay to xNES --- config.lua | 2 +- main.lua | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.lua b/config.lua index 36b7cb3..b9307f0 100644 --- a/config.lua +++ b/config.lua @@ -52,7 +52,7 @@ local common_cfg = { learning_rate = 1.0, mean_adapt = 1.0, -- for SNES, xNES. weight_decay = 0.0, - sigma_decay = 0.0, -- for SNES. + sigma_decay = 0.0, -- for SNES, xNES. } local cfg diff --git a/main.lua b/main.lua index 050559b..068d0bc 100644 --- a/main.lua +++ b/main.lua @@ -287,6 +287,15 @@ local function learn_from_epoch() base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i]) end end + elseif cfg.es == 'xnes' then + if cfg.sigma_decay > 0 then + es.sigma = es.sigma * (1 - cfg.sigma_decay) + end + if cfg.weight_decay > 0 then + for i, v in ipairs(base_params) do + base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay) + end + end else if cfg.weight_decay > 0 then for i, v in ipairs(base_params) do