diff --git a/ars.lua b/ars.lua index 03ac486..ef4df1a 100644 --- a/ars.lua +++ b/ars.lua @@ -64,7 +64,7 @@ function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) - self.weight_rate = base_rate + self.param_rate = base_rate self.sigma_rate = base_rate self.covar_rate = base_rate self.sigma = sigma or 1 @@ -190,7 +190,7 @@ function Ars:tell(scored, unperturbed_score) end for i, v in ipairs(self._params) do - self._params[i] = v + self.weight_rate * step[i] + self._params[i] = v + self.param_rate * step[i] end self.noise = nil diff --git a/config.lua b/config.lua index 1420019..8173282 100644 --- a/config.lua +++ b/config.lua @@ -50,9 +50,9 @@ local common_cfg = { min_refresh = 0.1, -- for SNES. -- epoch-related rates: - base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this + base_rate = 1.0, -- param_rate, sigma_rate, and covar_rate will base on this -- if you don't specify them individually. - weight_decay = 0.0, + param_decay = 0.0, sigma_decay = 0.0, -- for SNES, xNES. } @@ -75,9 +75,9 @@ if preset == 'snes' then deviation = 1.0, min_refresh = 0.2, - weight_rate = 0.5, + param_rate = 0.5, sigma_rate = 0.1, - weight_decay = 0.02, -- note: multiplied by its std, and weight_rate. + param_decay = 0.02, -- note: multiplied by its std, and param_rate. sigma_decay = 0.01, -- note: multiplied by sigma_rate. } @@ -108,9 +108,9 @@ elseif preset == 'snes2' then epoch_trials = 100, - weight_rate = 1.0, + param_rate = 1.0, sigma_rate = 0.01, - weight_decay = 0.02, + param_decay = 0.02, sigma_decay = 0.01, } @@ -140,7 +140,7 @@ elseif preset == 'xnes' then epoch_trials = 50, - weight_rate = 1.0, + param_rate = 1.0, sigma_rate = 0.01, covar_rate = 0.01, } @@ -171,10 +171,10 @@ elseif preset == 'xnes2' then epoch_trials = 10, --50, - weight_rate = 0.5, + param_rate = 0.5, sigma_rate = 0.04, covar_rate = 0.04, - weight_decay = 0.004, + param_decay = 0.004, sigma_decay = 0.00128, } @@ -202,8 +202,8 @@ elseif preset == 'ars' then epoch_trials = 25 * 2, - weight_rate = 1.0, - weight_decay = 0.0025, + param_rate = 1.0, + param_decay = 0.0025, } elseif preset == 'play' then @@ -224,9 +224,9 @@ end setmetatable(cfg, {__index=common_cfg}) -- gets overridden later. if cfg.es == 'ars' then - if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end + if cfg.param_rate == nil then cfg.param_rate = cfg.base_rate end else - if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end + if cfg.param_rate == nil then cfg.param_rate = 1.0 end end if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end diff --git a/main.lua b/main.lua index c554d06..7a1b105 100644 --- a/main.lua +++ b/main.lua @@ -283,24 +283,24 @@ local function learn_from_epoch() es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay) end end - if cfg.weight_decay > 0 then + if cfg.param_decay > 0 then for i, v in ipairs(base_params) do - base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i]) + base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i]) end end elseif cfg.es == 'xnes' then if cfg.sigma_decay > 0 then es.sigma = es.sigma * (1 - cfg.sigma_decay) end - if cfg.weight_decay > 0 then + if cfg.param_decay > 0 then for i, v in ipairs(base_params) do - base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay) + base_params[i] = v * (1 - es.param_rate * cfg.param_decay) end end else - if cfg.weight_decay > 0 then + if cfg.param_decay > 0 then for i, v in ipairs(base_params) do - base_params[i] = v * (1 - cfg.weight_decay) + base_params[i] = v * (1 - cfg.param_decay) end end end @@ -309,7 +309,7 @@ local function learn_from_epoch() local trial_mean, trial_std = calc_mean_dev(trial_rewards) local delta_mean, delta_std = calc_mean_dev(delta_rewards) - local weight_mean, weight_std = calc_mean_dev(base_params) + local param_mean, param_std = calc_mean_dev(base_params) log_csv{ epoch = epoch_i, @@ -318,8 +318,8 @@ local function learn_from_epoch() delta_mean = delta_mean, delta_std = delta_std, step_std = step_dev, - weight_mean = weight_mean, - weight_std = weight_std, + weight_mean = param_mean, + weight_std = param_std, test_trial = current_cost or 0, decisions = decisions_made, } @@ -335,7 +335,7 @@ local function learn_from_epoch() end else - print("note: not updating weights in playable mode.") + print("note: not updating params in playable mode.") end print() @@ -516,7 +516,7 @@ local function init() error("Unknown evolution strategy specified: " + tostring(cfg.es)) end - es.weight_rate = cfg.weight_rate + es.param_rate = cfg.param_rate es.sigma_rate = cfg.sigma_rate es.covar_rate = cfg.covar_rate es.rate_init = cfg.sigma_rate -- just for SNES? diff --git a/snes.lua b/snes.lua index 164f6f6..282264a 100644 --- a/snes.lua +++ b/snes.lua @@ -37,7 +37,7 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) - self.weight_rate = 1.0 + self.param_rate = 1.0 self.sigma_rate = base_rate self.covar_rate = base_rate self.sigma = sigma or 1 @@ -251,7 +251,7 @@ function Snes:tell(scored) end for i, v in ipairs(self.mean) do - self.mean[i] = v + self.weight_rate * step[i] + self.mean[i] = v + self.param_rate * step[i] end local otherwise = {} diff --git a/xnes.lua b/xnes.lua index 2be4c14..241a98d 100644 --- a/xnes.lua +++ b/xnes.lua @@ -56,7 +56,7 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) - self.weight_rate = 1.0 + self.param_rate = 1.0 self.sigma_rate = base_rate self.covar_rate = base_rate self.sigma = sigma or 1 @@ -183,7 +183,7 @@ function Xnes:tell(scored, noise) end for i, v in ipairs(self.mean) do - self.mean[i] = v + self.weight_rate * step[i] + self.mean[i] = v + self.param_rate * step[i] end self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)