overhaul learning rates:
- rename mean_adapt to weight_rate
- sigma and covar update rates can be specified separately (sigma_rate, covar_rate)
- base decays on current rates instead of initially configured rates (this might break stuff)
- base_rate takes the place of learning_rate
This commit is contained in:
parent
474bac45b8
commit
f3fc95404c
5 changed files with 64 additions and 44 deletions
9
ars.lua
9
ars.lua
|
@@ -60,10 +60,13 @@ local function kinda_lipschitz(dir, pos, neg, mid)
|
|||
return max(l0, l1) / (2 * dev)
|
||||
end
|
||||
|
||||
function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
|
||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.weight_rate = base_rate
|
||||
self.sigma_rate = base_rate
|
||||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
|
||||
|
@@ -187,7 +190,7 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + self.learning_rate * step[i]
|
||||
self._params[i] = v + self.weight_rate * step[i]
|
||||
end
|
||||
|
||||
self.noise = nil
|
||||
|
|
37
config.lua
37
config.lua
|
@@ -50,8 +50,8 @@ local common_cfg = {
|
|||
min_refresh = 0.1, -- for SNES.
|
||||
|
||||
-- epoch-related rates:
|
||||
learning_rate = 1.0,
|
||||
mean_adapt = 1.0, -- for SNES, xNES.
|
||||
base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this
|
||||
-- if you don't specify them individually.
|
||||
weight_decay = 0.0,
|
||||
sigma_decay = 0.0, -- for SNES, xNES.
|
||||
}
|
||||
|
@@ -75,10 +75,10 @@ if preset == 'snes' then
|
|||
deviation = 1.0,
|
||||
min_refresh = 0.2,
|
||||
|
||||
learning_rate = 0.1, -- TODO: rename to learn_primary or something.
|
||||
mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
|
||||
weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt.
|
||||
sigma_decay = 0.01, -- note: multiplied by learning_rate.
|
||||
weight_rate = 0.5,
|
||||
sigma_rate = 0.1,
|
||||
weight_decay = 0.02, -- note: multiplied by its std, and weight_rate.
|
||||
sigma_decay = 0.01, -- note: multiplied by sigma_rate.
|
||||
}
|
||||
|
||||
elseif preset == 'snes2' then
|
||||
|
@@ -108,8 +108,8 @@ elseif preset == 'snes2' then
|
|||
|
||||
epoch_trials = 100,
|
||||
|
||||
learning_rate = 0.01,
|
||||
mean_adapt = 1.0,
|
||||
weight_rate = 1.0,
|
||||
sigma_rate = 0.01,
|
||||
weight_decay = 0.02,
|
||||
sigma_decay = 0.01,
|
||||
}
|
||||
|
@@ -140,7 +140,9 @@ elseif preset == 'xnes' then
|
|||
|
||||
epoch_trials = 50,
|
||||
|
||||
learning_rate = 0.01,
|
||||
weight_rate = 1.0,
|
||||
sigma_rate = 0.01,
|
||||
covar_rate = 0.01,
|
||||
}
|
||||
|
||||
elseif preset == 'xnes2' then
|
||||
|
@@ -169,8 +171,9 @@ elseif preset == 'xnes2' then
|
|||
|
||||
epoch_trials = 10, --50,
|
||||
|
||||
learning_rate = 0.04,
|
||||
mean_adapt = 0.5,
|
||||
weight_rate = 0.5,
|
||||
sigma_rate = 0.04,
|
||||
covar_rate = 0.04,
|
||||
weight_decay = 0.004,
|
||||
sigma_decay = 0.00128,
|
||||
}
|
||||
|
@@ -199,7 +202,7 @@ elseif preset == 'ars' then
|
|||
|
||||
epoch_trials = 25 * 2,
|
||||
|
||||
learning_rate = 1.0,
|
||||
weight_rate = 1.0,
|
||||
weight_decay = 0.0025,
|
||||
}
|
||||
|
||||
|
@@ -218,6 +221,16 @@ end
|
|||
-- TODO: so, uhh..
|
||||
-- what happens when playback_mode is true but unperturbed_trial is false?
|
||||
|
||||
setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.
|
||||
|
||||
if cfg.es == 'ars' then
|
||||
if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end
|
||||
else
|
||||
if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end
|
||||
end
|
||||
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
|
||||
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end
|
||||
|
||||
setmetatable(cfg, {
|
||||
__index = function(t, n)
|
||||
if common_cfg[n] ~= nil then return common_cfg[n] end
|
||||
|
|
21
main.lua
21
main.lua
|
@@ -276,15 +276,16 @@ local function learn_from_epoch()
|
|||
|
||||
base_params = es:params()
|
||||
|
||||
-- TODO: move this all to es:decay methods.
|
||||
if cfg.es == 'snes' then
|
||||
if cfg.sigma_decay > 0 then
|
||||
for i, v in ipairs(es.std) do
|
||||
es.std[i] = v * (1 - cfg.learning_rate * cfg.sigma_decay)
|
||||
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
|
||||
end
|
||||
end
|
||||
if cfg.weight_decay > 0 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
|
||||
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i])
|
||||
end
|
||||
end
|
||||
elseif cfg.es == 'xnes' then
|
||||
|
@@ -293,7 +294,7 @@ local function learn_from_epoch()
|
|||
end
|
||||
if cfg.weight_decay > 0 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay)
|
||||
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay)
|
||||
end
|
||||
end
|
||||
else
|
||||
|
@@ -494,14 +495,11 @@ local function init()
|
|||
-- if you get an out of memory error, you can't use xNES. sorry!
|
||||
-- maybe there'll be a patch for FCEUX in the future.
|
||||
es = xnes.Xnes(network.n_param, cfg.epoch_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
-- TODO: clean this up into an interface:
|
||||
es.mean_adapt = cfg.mean_adapt
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials)
|
||||
elseif cfg.es == 'snes' then
|
||||
es = snes.Snes(network.n_param, cfg.epoch_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials)
|
||||
-- TODO: clean this up into an interface:
|
||||
es.mean_adapt = cfg.mean_adapt
|
||||
es.min_refresh = cfg.min_refresh
|
||||
|
||||
if exists(std_fn) then
|
||||
|
@@ -513,11 +511,16 @@ local function init()
|
|||
end
|
||||
elseif cfg.es == 'ars' then
|
||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials)
|
||||
else
|
||||
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
||||
end
|
||||
|
||||
es.weight_rate = cfg.weight_rate
|
||||
es.sigma_rate = cfg.sigma_rate
|
||||
es.covar_rate = cfg.covar_rate
|
||||
es.rate_init = cfg.sigma_rate -- just for SNES?
|
||||
|
||||
es:params(network:collect())
|
||||
end
|
||||
|
||||
|
|
25
snes.lua
25
snes.lua
|
@@ -32,24 +32,25 @@ local weighted_mann_whitney = util.weighted_mann_whitney
|
|||
|
||||
local Snes = Base:extend()
|
||||
|
||||
function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.weight_rate = 1.0
|
||||
self.sigma_rate = base_rate
|
||||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self.rate_init = self.learning_rate
|
||||
self.rate_init = self.sigma_rate
|
||||
|
||||
self.mean = zeros{dims}
|
||||
self.std = zeros{dims}
|
||||
for i=1, self.dims do self.std[i] = self.sigma end
|
||||
|
||||
self.mean_adapt = 1.0
|
||||
|
||||
self.old_asked = {}
|
||||
self.old_noise = {}
|
||||
self.old_score = {}
|
||||
|
@@ -250,15 +251,15 @@ function Snes:tell(scored)
|
|||
end
|
||||
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v + self.mean_adapt * step[i]
|
||||
self.mean[i] = v + self.weight_rate * step[i]
|
||||
end
|
||||
|
||||
local otherwise = {}
|
||||
self.std_old = {}
|
||||
for i, v in ipairs(self.std) do
|
||||
self.std_old[i] = v
|
||||
self.std[i] = v * exp(self.learning_rate * 0.5 * g_std[i])
|
||||
otherwise[i] = v * exp(self.learning_rate * 0.75 * g_std[i])
|
||||
self.std[i] = v * exp(self.sigma_rate * 0.5 * g_std[i])
|
||||
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
|
||||
end
|
||||
|
||||
self:adapt(asked, otherwise, utility)
|
||||
|
@@ -283,11 +284,11 @@ function Snes:adapt(asked, otherwise, qualities)
|
|||
--print("p:", p)
|
||||
|
||||
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
|
||||
self.learning_rate = 0.9 * self.learning_rate + 0.1 * self.rate_init
|
||||
print("learning rate -:", self.learning_rate)
|
||||
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
|
||||
print("learning rate -:", self.sigma_rate)
|
||||
else
|
||||
self.learning_rate = min(1.1 * self.learning_rate, 1)
|
||||
print("learning rate +:", self.learning_rate)
|
||||
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
|
||||
print("learning rate +:", self.sigma_rate)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
16
xnes.lua
16
xnes.lua
|
@@ -51,11 +51,14 @@ local function make_covars(dims, sigma, out)
|
|||
return covars
|
||||
end
|
||||
|
||||
function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||
function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.weight_rate = 1.0
|
||||
self.sigma_rate = base_rate
|
||||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
|
||||
|
@@ -67,8 +70,6 @@ function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
|
|||
-- note: this is technically the co-standard-deviation.
|
||||
-- you can imagine the "s" standing for "sqrt" if you like.
|
||||
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
||||
|
||||
self.mean_adapt = 1.0
|
||||
end
|
||||
|
||||
function Xnes:params(new_mean)
|
||||
|
@@ -182,13 +183,12 @@ function Xnes:tell(scored, noise)
|
|||
end
|
||||
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v + self.mean_adapt * step[i]
|
||||
self.mean[i] = v + self.weight_rate * step[i]
|
||||
end
|
||||
|
||||
local lr = self.learning_rate * 0.5
|
||||
self.sigma = self.sigma * exp(lr * g_sigma)
|
||||
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
|
||||
for i, v in ipairs(self.covars) do
|
||||
self.covars[i] = v * exp(lr * g_covars[i])
|
||||
self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i])
|
||||
end
|
||||
|
||||
-- bookkeeping:
|
||||
|
|
Loading…
Reference in a new issue