overhaul learning rates:

- rename mean_adapt to weight_rate
- sigma and covar update rates can be specified separately
  (sigma_rate, covar_rate)
- base decays on current rates instead of initially configured rates
  (this might break stuff)
- base_rate takes the place of learning_rate
This commit is contained in:
Connor Olding 2018-06-16 00:24:55 +02:00
parent 474bac45b8
commit f3fc95404c
5 changed files with 64 additions and 44 deletions

View File

@ -60,10 +60,13 @@ local function kinda_lipschitz(dir, pos, neg, mid)
return max(l0, l1) / (2 * dev)
end
function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.weight_rate = base_rate
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
@ -187,7 +190,7 @@ function Ars:tell(scored, unperturbed_score)
end
for i, v in ipairs(self._params) do
self._params[i] = v + self.learning_rate * step[i]
self._params[i] = v + self.weight_rate * step[i]
end
self.noise = nil

View File

@ -50,8 +50,8 @@ local common_cfg = {
min_refresh = 0.1, -- for SNES.
-- epoch-related rates:
learning_rate = 1.0,
mean_adapt = 1.0, -- for SNES, xNES.
base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this
-- if you don't specify them individually.
weight_decay = 0.0,
sigma_decay = 0.0, -- for SNES, xNES.
}
@ -75,10 +75,10 @@ if preset == 'snes' then
deviation = 1.0,
min_refresh = 0.2,
learning_rate = 0.1, -- TODO: rename to learn_primary or something.
mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt.
sigma_decay = 0.01, -- note: multiplied by learning_rate.
weight_rate = 0.5,
sigma_rate = 0.1,
weight_decay = 0.02, -- note: multiplied by its std, and weight_rate.
sigma_decay = 0.01, -- note: multiplied by sigma_rate.
}
elseif preset == 'snes2' then
@ -108,8 +108,8 @@ elseif preset == 'snes2' then
epoch_trials = 100,
learning_rate = 0.01,
mean_adapt = 1.0,
weight_rate = 1.0,
sigma_rate = 0.01,
weight_decay = 0.02,
sigma_decay = 0.01,
}
@ -140,7 +140,9 @@ elseif preset == 'xnes' then
epoch_trials = 50,
learning_rate = 0.01,
weight_rate = 1.0,
sigma_rate = 0.01,
covar_rate = 0.01,
}
elseif preset == 'xnes2' then
@ -169,8 +171,9 @@ elseif preset == 'xnes2' then
epoch_trials = 10, --50,
learning_rate = 0.04,
mean_adapt = 0.5,
weight_rate = 0.5,
sigma_rate = 0.04,
covar_rate = 0.04,
weight_decay = 0.004,
sigma_decay = 0.00128,
}
@ -199,7 +202,7 @@ elseif preset == 'ars' then
epoch_trials = 25 * 2,
learning_rate = 1.0,
weight_rate = 1.0,
weight_decay = 0.0025,
}
@ -218,6 +221,16 @@ end
-- TODO: so, uhh..
-- what happens when playback_mode is true but unperturbed_trial is false?
setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.
if cfg.es == 'ars' then
if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end
else
if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end
end
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end
setmetatable(cfg, {
__index = function(t, n)
if common_cfg[n] ~= nil then return common_cfg[n] end

View File

@ -276,15 +276,16 @@ local function learn_from_epoch()
base_params = es:params()
-- TODO: move this all to es:decay methods.
if cfg.es == 'snes' then
if cfg.sigma_decay > 0 then
for i, v in ipairs(es.std) do
es.std[i] = v * (1 - cfg.learning_rate * cfg.sigma_decay)
es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
end
end
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i])
end
end
elseif cfg.es == 'xnes' then
@ -293,7 +294,7 @@ local function learn_from_epoch()
end
if cfg.weight_decay > 0 then
for i, v in ipairs(base_params) do
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay)
base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay)
end
end
else
@ -494,14 +495,11 @@ local function init()
-- if you get an out of memory error, you can't use xNES. sorry!
-- maybe there'll be a patch for FCEUX in the future.
es = xnes.Xnes(network.n_param, cfg.epoch_trials,
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
-- TODO: clean this up into an interface:
es.mean_adapt = cfg.mean_adapt
cfg.base_rate, cfg.deviation, cfg.negate_trials)
elseif cfg.es == 'snes' then
es = snes.Snes(network.n_param, cfg.epoch_trials,
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
cfg.base_rate, cfg.deviation, cfg.negate_trials)
-- TODO: clean this up into an interface:
es.mean_adapt = cfg.mean_adapt
es.min_refresh = cfg.min_refresh
if exists(std_fn) then
@ -513,11 +511,16 @@ local function init()
end
elseif cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
cfg.base_rate, cfg.deviation, cfg.negate_trials)
else
error("Unknown evolution strategy specified: " + tostring(cfg.es))
end
es.weight_rate = cfg.weight_rate
es.sigma_rate = cfg.sigma_rate
es.covar_rate = cfg.covar_rate
es.rate_init = cfg.sigma_rate -- just for SNES?
es:params(network:collect())
end

View File

@ -32,24 +32,25 @@ local weighted_mann_whitney = util.weighted_mann_whitney
local Snes = Base:extend()
function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
-- heuristic borrowed from CMA-ES:
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.weight_rate = 1.0
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
if self.antithetic then self.popsize = self.popsize * 2 end
self.rate_init = self.learning_rate
self.rate_init = self.sigma_rate
self.mean = zeros{dims}
self.std = zeros{dims}
for i=1, self.dims do self.std[i] = self.sigma end
self.mean_adapt = 1.0
self.old_asked = {}
self.old_noise = {}
self.old_score = {}
@ -250,15 +251,15 @@ function Snes:tell(scored)
end
for i, v in ipairs(self.mean) do
self.mean[i] = v + self.mean_adapt * step[i]
self.mean[i] = v + self.weight_rate * step[i]
end
local otherwise = {}
self.std_old = {}
for i, v in ipairs(self.std) do
self.std_old[i] = v
self.std[i] = v * exp(self.learning_rate * 0.5 * g_std[i])
otherwise[i] = v * exp(self.learning_rate * 0.75 * g_std[i])
self.std[i] = v * exp(self.sigma_rate * 0.5 * g_std[i])
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
end
self:adapt(asked, otherwise, utility)
@ -283,11 +284,11 @@ function Snes:adapt(asked, otherwise, qualities)
--print("p:", p)
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
self.learning_rate = 0.9 * self.learning_rate + 0.1 * self.rate_init
print("learning rate -:", self.learning_rate)
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
print("learning rate -:", self.sigma_rate)
else
self.learning_rate = min(1.1 * self.learning_rate, 1)
print("learning rate +:", self.learning_rate)
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
print("learning rate +:", self.sigma_rate)
end
end

View File

@ -51,11 +51,14 @@ local function make_covars(dims, sigma, out)
return covars
end
function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
-- heuristic borrowed from CMA-ES:
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.weight_rate = 1.0
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
@ -67,8 +70,6 @@ function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
-- note: this is technically the co-standard-deviation.
-- you can imagine the "s" standing for "sqrt" if you like.
self.covars = make_covars(self.dims, self.sigma, self.covars)
self.mean_adapt = 1.0
end
function Xnes:params(new_mean)
@ -182,13 +183,12 @@ function Xnes:tell(scored, noise)
end
for i, v in ipairs(self.mean) do
self.mean[i] = v + self.mean_adapt * step[i]
self.mean[i] = v + self.weight_rate * step[i]
end
local lr = self.learning_rate * 0.5
self.sigma = self.sigma * exp(lr * g_sigma)
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
for i, v in ipairs(self.covars) do
self.covars[i] = v * exp(lr * g_covars[i])
self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i])
end
-- bookkeeping: