rename weight* to param* outside of nn.lua

This commit is contained in:
Connor Olding 2018-06-16 00:33:11 +02:00
parent f3fc95404c
commit e3695bfb84
5 changed files with 30 additions and 30 deletions

View File

@ -64,7 +64,7 @@ function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.weight_rate = base_rate
self.param_rate = base_rate
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
@ -190,7 +190,7 @@ function Ars:tell(scored, unperturbed_score)
end
for i, v in ipairs(self._params) do
self._params[i] = v + self.weight_rate * step[i]
self._params[i] = v + self.param_rate * step[i]
end
self.noise = nil

View File

@ -50,9 +50,9 @@ local common_cfg = {
min_refresh = 0.1, -- for SNES.
-- epoch-related rates:
base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this
base_rate = 1.0, -- param_rate, sigma_rate, and covar_rate will base on this
-- if you don't specify them individually.
weight_decay = 0.0,
param_decay = 0.0,
sigma_decay = 0.0, -- for SNES, xNES.
}
@ -75,9 +75,9 @@ if preset == 'snes' then
deviation = 1.0,
min_refresh = 0.2,
weight_rate = 0.5,
param_rate = 0.5,
sigma_rate = 0.1,
weight_decay = 0.02, -- note: multiplied by its std, and weight_rate.
param_decay = 0.02, -- note: multiplied by its std, and param_rate.
sigma_decay = 0.01, -- note: multiplied by sigma_rate.
}
@ -108,9 +108,9 @@ elseif preset == 'snes2' then
epoch_trials = 100,
weight_rate = 1.0,
param_rate = 1.0,
sigma_rate = 0.01,
weight_decay = 0.02,
param_decay = 0.02,
sigma_decay = 0.01,
}
@ -140,7 +140,7 @@ elseif preset == 'xnes' then
epoch_trials = 50,
weight_rate = 1.0,
param_rate = 1.0,
sigma_rate = 0.01,
covar_rate = 0.01,
}
@ -171,10 +171,10 @@ elseif preset == 'xnes2' then
epoch_trials = 10, --50,
weight_rate = 0.5,
param_rate = 0.5,
sigma_rate = 0.04,
covar_rate = 0.04,
weight_decay = 0.004,
param_decay = 0.004,
sigma_decay = 0.00128,
}
@ -202,8 +202,8 @@ elseif preset == 'ars' then
epoch_trials = 25 * 2,
weight_rate = 1.0,
weight_decay = 0.0025,
param_rate = 1.0,
param_decay = 0.0025,
}
elseif preset == 'play' then
@ -224,9 +224,9 @@ end
setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.
if cfg.es == 'ars' then
if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end
if cfg.param_rate == nil then cfg.param_rate = cfg.base_rate end
else
if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end
if cfg.param_rate == nil then cfg.param_rate = 1.0 end
end
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end

View File

@ -283,24 +283,24 @@ local function learn_from_epoch()
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
end
end
if cfg.weight_decay > 0 then
if cfg.param_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i])
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i])
end
end
elseif cfg.es == 'xnes' then
if cfg.sigma_decay > 0 then
es.sigma = es.sigma * (1 - cfg.sigma_decay)
end
if cfg.weight_decay > 0 then
if cfg.param_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay)
base_params[i] = v * (1 - es.param_rate * cfg.param_decay)
end
end
else
if cfg.weight_decay > 0 then
if cfg.param_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.weight_decay)
base_params[i] = v * (1 - cfg.param_decay)
end
end
end
@ -309,7 +309,7 @@ local function learn_from_epoch()
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
local weight_mean, weight_std = calc_mean_dev(base_params)
local param_mean, param_std = calc_mean_dev(base_params)
log_csv{
epoch = epoch_i,
@ -318,8 +318,8 @@ local function learn_from_epoch()
delta_mean = delta_mean,
delta_std = delta_std,
step_std = step_dev,
weight_mean = weight_mean,
weight_std = weight_std,
weight_mean = param_mean,
weight_std = param_std,
test_trial = current_cost or 0,
decisions = decisions_made,
}
@ -335,7 +335,7 @@ local function learn_from_epoch()
end
else
print("note: not updating weights in playable mode.")
print("note: not updating params in playable mode.")
end
print()
@ -516,7 +516,7 @@ local function init()
error("Unknown evolution strategy specified: " + tostring(cfg.es))
end
es.weight_rate = cfg.weight_rate
es.param_rate = cfg.param_rate
es.sigma_rate = cfg.sigma_rate
es.covar_rate = cfg.covar_rate
es.rate_init = cfg.sigma_rate -- just for SNES?

View File

@ -37,7 +37,7 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.weight_rate = 1.0
self.param_rate = 1.0
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
@ -251,7 +251,7 @@ function Snes:tell(scored)
end
for i, v in ipairs(self.mean) do
self.mean[i] = v + self.weight_rate * step[i]
self.mean[i] = v + self.param_rate * step[i]
end
local otherwise = {}

View File

@ -56,7 +56,7 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.weight_rate = 1.0
self.param_rate = 1.0
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
@ -183,7 +183,7 @@ function Xnes:tell(scored, noise)
end
for i, v in ipairs(self.mean) do
self.mean[i] = v + self.weight_rate * step[i]
self.mean[i] = v + self.param_rate * step[i]
end
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)