rename weight* to param* outside of nn.lua
This commit is contained in:
parent
f3fc95404c
commit
e3695bfb84
5 changed files with 30 additions and 30 deletions
4
ars.lua
4
ars.lua
|
@ -64,7 +64,7 @@ function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||||
self.weight_rate = base_rate
|
self.param_rate = base_rate
|
||||||
self.sigma_rate = base_rate
|
self.sigma_rate = base_rate
|
||||||
self.covar_rate = base_rate
|
self.covar_rate = base_rate
|
||||||
self.sigma = sigma or 1
|
self.sigma = sigma or 1
|
||||||
|
@ -190,7 +190,7 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
end
|
end
|
||||||
|
|
||||||
for i, v in ipairs(self._params) do
|
for i, v in ipairs(self._params) do
|
||||||
self._params[i] = v + self.weight_rate * step[i]
|
self._params[i] = v + self.param_rate * step[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
self.noise = nil
|
self.noise = nil
|
||||||
|
|
26
config.lua
26
config.lua
|
@ -50,9 +50,9 @@ local common_cfg = {
|
||||||
min_refresh = 0.1, -- for SNES.
|
min_refresh = 0.1, -- for SNES.
|
||||||
|
|
||||||
-- epoch-related rates:
|
-- epoch-related rates:
|
||||||
base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this
|
base_rate = 1.0, -- param_rate, sigma_rate, and covar_rate will be based on this
|
||||||
-- if you don't specify them individually.
|
-- if you don't specify them individually.
|
||||||
weight_decay = 0.0,
|
param_decay = 0.0,
|
||||||
sigma_decay = 0.0, -- for SNES, xNES.
|
sigma_decay = 0.0, -- for SNES, xNES.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,9 +75,9 @@ if preset == 'snes' then
|
||||||
deviation = 1.0,
|
deviation = 1.0,
|
||||||
min_refresh = 0.2,
|
min_refresh = 0.2,
|
||||||
|
|
||||||
weight_rate = 0.5,
|
param_rate = 0.5,
|
||||||
sigma_rate = 0.1,
|
sigma_rate = 0.1,
|
||||||
weight_decay = 0.02, -- note: multiplied by its std, and weight_rate.
|
param_decay = 0.02, -- note: multiplied by its std, and param_rate.
|
||||||
sigma_decay = 0.01, -- note: multiplied by sigma_rate.
|
sigma_decay = 0.01, -- note: multiplied by sigma_rate.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,9 +108,9 @@ elseif preset == 'snes2' then
|
||||||
|
|
||||||
epoch_trials = 100,
|
epoch_trials = 100,
|
||||||
|
|
||||||
weight_rate = 1.0,
|
param_rate = 1.0,
|
||||||
sigma_rate = 0.01,
|
sigma_rate = 0.01,
|
||||||
weight_decay = 0.02,
|
param_decay = 0.02,
|
||||||
sigma_decay = 0.01,
|
sigma_decay = 0.01,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,7 +140,7 @@ elseif preset == 'xnes' then
|
||||||
|
|
||||||
epoch_trials = 50,
|
epoch_trials = 50,
|
||||||
|
|
||||||
weight_rate = 1.0,
|
param_rate = 1.0,
|
||||||
sigma_rate = 0.01,
|
sigma_rate = 0.01,
|
||||||
covar_rate = 0.01,
|
covar_rate = 0.01,
|
||||||
}
|
}
|
||||||
|
@ -171,10 +171,10 @@ elseif preset == 'xnes2' then
|
||||||
|
|
||||||
epoch_trials = 10, --50,
|
epoch_trials = 10, --50,
|
||||||
|
|
||||||
weight_rate = 0.5,
|
param_rate = 0.5,
|
||||||
sigma_rate = 0.04,
|
sigma_rate = 0.04,
|
||||||
covar_rate = 0.04,
|
covar_rate = 0.04,
|
||||||
weight_decay = 0.004,
|
param_decay = 0.004,
|
||||||
sigma_decay = 0.00128,
|
sigma_decay = 0.00128,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,8 +202,8 @@ elseif preset == 'ars' then
|
||||||
|
|
||||||
epoch_trials = 25 * 2,
|
epoch_trials = 25 * 2,
|
||||||
|
|
||||||
weight_rate = 1.0,
|
param_rate = 1.0,
|
||||||
weight_decay = 0.0025,
|
param_decay = 0.0025,
|
||||||
}
|
}
|
||||||
|
|
||||||
elseif preset == 'play' then
|
elseif preset == 'play' then
|
||||||
|
@ -224,9 +224,9 @@ end
|
||||||
setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.
|
setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.
|
||||||
|
|
||||||
if cfg.es == 'ars' then
|
if cfg.es == 'ars' then
|
||||||
if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end
|
if cfg.param_rate == nil then cfg.param_rate = cfg.base_rate end
|
||||||
else
|
else
|
||||||
if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end
|
if cfg.param_rate == nil then cfg.param_rate = 1.0 end
|
||||||
end
|
end
|
||||||
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
|
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
|
||||||
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end
|
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end
|
||||||
|
|
22
main.lua
22
main.lua
|
@ -283,24 +283,24 @@ local function learn_from_epoch()
|
||||||
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
|
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if cfg.weight_decay > 0 then
|
if cfg.param_decay > 0 then
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i])
|
base_params[i] = v * (1 - es.param_rate * cfg.param_decay * es.std[i])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
elseif cfg.es == 'xnes' then
|
elseif cfg.es == 'xnes' then
|
||||||
if cfg.sigma_decay > 0 then
|
if cfg.sigma_decay > 0 then
|
||||||
es.sigma = es.sigma * (1 - cfg.sigma_decay)
|
es.sigma = es.sigma * (1 - cfg.sigma_decay)
|
||||||
end
|
end
|
||||||
if cfg.weight_decay > 0 then
|
if cfg.param_decay > 0 then
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay)
|
base_params[i] = v * (1 - es.param_rate * cfg.param_decay)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
if cfg.weight_decay > 0 then
|
if cfg.param_decay > 0 then
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v * (1 - cfg.weight_decay)
|
base_params[i] = v * (1 - cfg.param_decay)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -309,7 +309,7 @@ local function learn_from_epoch()
|
||||||
|
|
||||||
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
||||||
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
||||||
local weight_mean, weight_std = calc_mean_dev(base_params)
|
local param_mean, param_std = calc_mean_dev(base_params)
|
||||||
|
|
||||||
log_csv{
|
log_csv{
|
||||||
epoch = epoch_i,
|
epoch = epoch_i,
|
||||||
|
@ -318,8 +318,8 @@ local function learn_from_epoch()
|
||||||
delta_mean = delta_mean,
|
delta_mean = delta_mean,
|
||||||
delta_std = delta_std,
|
delta_std = delta_std,
|
||||||
step_std = step_dev,
|
step_std = step_dev,
|
||||||
weight_mean = weight_mean,
|
weight_mean = param_mean,
|
||||||
weight_std = weight_std,
|
weight_std = param_std,
|
||||||
test_trial = current_cost or 0,
|
test_trial = current_cost or 0,
|
||||||
decisions = decisions_made,
|
decisions = decisions_made,
|
||||||
}
|
}
|
||||||
|
@ -335,7 +335,7 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
else
|
else
|
||||||
print("note: not updating weights in playable mode.")
|
print("note: not updating params in playable mode.")
|
||||||
end
|
end
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
@ -516,7 +516,7 @@ local function init()
|
||||||
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
||||||
end
|
end
|
||||||
|
|
||||||
es.weight_rate = cfg.weight_rate
|
es.param_rate = cfg.param_rate
|
||||||
es.sigma_rate = cfg.sigma_rate
|
es.sigma_rate = cfg.sigma_rate
|
||||||
es.covar_rate = cfg.covar_rate
|
es.covar_rate = cfg.covar_rate
|
||||||
es.rate_init = cfg.sigma_rate -- just for SNES?
|
es.rate_init = cfg.sigma_rate -- just for SNES?
|
||||||
|
|
4
snes.lua
4
snes.lua
|
@ -37,7 +37,7 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||||
self.weight_rate = 1.0
|
self.param_rate = 1.0
|
||||||
self.sigma_rate = base_rate
|
self.sigma_rate = base_rate
|
||||||
self.covar_rate = base_rate
|
self.covar_rate = base_rate
|
||||||
self.sigma = sigma or 1
|
self.sigma = sigma or 1
|
||||||
|
@ -251,7 +251,7 @@ function Snes:tell(scored)
|
||||||
end
|
end
|
||||||
|
|
||||||
for i, v in ipairs(self.mean) do
|
for i, v in ipairs(self.mean) do
|
||||||
self.mean[i] = v + self.weight_rate * step[i]
|
self.mean[i] = v + self.param_rate * step[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
local otherwise = {}
|
local otherwise = {}
|
||||||
|
|
4
xnes.lua
4
xnes.lua
|
@ -56,7 +56,7 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||||
self.weight_rate = 1.0
|
self.param_rate = 1.0
|
||||||
self.sigma_rate = base_rate
|
self.sigma_rate = base_rate
|
||||||
self.covar_rate = base_rate
|
self.covar_rate = base_rate
|
||||||
self.sigma = sigma or 1
|
self.sigma = sigma or 1
|
||||||
|
@ -183,7 +183,7 @@ function Xnes:tell(scored, noise)
|
||||||
end
|
end
|
||||||
|
|
||||||
for i, v in ipairs(self.mean) do
|
for i, v in ipairs(self.mean) do
|
||||||
self.mean[i] = v + self.weight_rate * step[i]
|
self.mean[i] = v + self.param_rate * step[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
|
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
|
||||||
|
|
Loading…
Add table
Reference in a new issue