overhaul learning rates:

- rename mean_adapt to weight_rate
- sigma and covar update rates can be specified separately
  (sigma_rate, covar_rate)
- base decays on current rates instead of initially configured rates
  (this might break stuff)
- base_rate takes the place of learning_rate
Connor Olding 2018-06-16 00:24:55 +02:00
parent 474bac45b8
commit f3fc95404c
5 changed files with 64 additions and 44 deletions
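
For context, here is how the rates resolve after this change, distilled from the config fallback added below. This is an illustrative sketch, not code from the repository; resolve_rates is a hypothetical helper, but the key names and fallback order are the ones this commit introduces: base_rate only seeds the other rates, and weight_rate, sigma_rate, and covar_rate can still be set individually.

    -- hypothetical helper mirroring the fallback logic added to the config below
    local function resolve_rates(cfg)
        if cfg.es == 'ars' then
            -- ARS scales its parameter step directly by the base rate.
            if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end
        else
            -- SNES/xNES default to a full-strength mean update.
            if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end
        end
        if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
        if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end
        return cfg
    end

    -- e.g. resolve_rates{es='xnes', base_rate=0.01}
    -- gives weight_rate = 1.0, sigma_rate = 0.01, covar_rate = 0.01,
    -- matching the 'xnes' preset below.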

@@ -60,10 +60,13 @@ local function kinda_lipschitz(dir, pos, neg, mid)
     return max(l0, l1) / (2 * dev)
 end
-function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
+function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic)
     self.dims = dims
     self.popsize = popsize or 4 + (3 * floor(log(dims)))
-    self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
+    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
+    self.weight_rate = base_rate
+    self.sigma_rate = base_rate
+    self.covar_rate = base_rate
     self.sigma = sigma or 1
     self.antithetic = antithetic and true or false
@@ -187,7 +190,7 @@ function Ars:tell(scored, unperturbed_score)
     end
     for i, v in ipairs(self._params) do
-        self._params[i] = v + self.learning_rate * step[i]
+        self._params[i] = v + self.weight_rate * step[i]
     end
     self.noise = nil

@@ -50,8 +50,8 @@ local common_cfg = {
     min_refresh = 0.1, -- for SNES.
     -- epoch-related rates:
-    learning_rate = 1.0,
-    mean_adapt = 1.0, -- for SNES, xNES.
+    base_rate = 1.0, -- weight_rate, sigma_rate, and covar_rate will base on this
+    -- if you don't specify them individually.
     weight_decay = 0.0,
     sigma_decay = 0.0, -- for SNES, xNES.
 }
@@ -75,10 +75,10 @@ if preset == 'snes' then
         deviation = 1.0,
         min_refresh = 0.2,
-        learning_rate = 0.1, -- TODO: rename to learn_primary or something.
-        mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
-        weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt.
-        sigma_decay = 0.01, -- note: multiplied by learning_rate.
+        weight_rate = 0.5,
+        sigma_rate = 0.1,
+        weight_decay = 0.02, -- note: multiplied by its std, and weight_rate.
+        sigma_decay = 0.01, -- note: multiplied by sigma_rate.
     }
 elseif preset == 'snes2' then
@@ -108,8 +108,8 @@ elseif preset == 'snes2' then
         epoch_trials = 100,
-        learning_rate = 0.01,
-        mean_adapt = 1.0,
+        weight_rate = 1.0,
+        sigma_rate = 0.01,
         weight_decay = 0.02,
         sigma_decay = 0.01,
     }
@@ -140,7 +140,9 @@ elseif preset == 'xnes' then
         epoch_trials = 50,
-        learning_rate = 0.01,
+        weight_rate = 1.0,
+        sigma_rate = 0.01,
+        covar_rate = 0.01,
     }
 elseif preset == 'xnes2' then
@@ -169,8 +171,9 @@ elseif preset == 'xnes2' then
         epoch_trials = 10, --50,
-        learning_rate = 0.04,
-        mean_adapt = 0.5,
+        weight_rate = 0.5,
+        sigma_rate = 0.04,
+        covar_rate = 0.04,
         weight_decay = 0.004,
         sigma_decay = 0.00128,
     }
@@ -199,7 +202,7 @@ elseif preset == 'ars' then
         epoch_trials = 25 * 2,
-        learning_rate = 1.0,
+        weight_rate = 1.0,
         weight_decay = 0.0025,
     }
@@ -218,6 +221,16 @@ end
 -- TODO: so, uhh..
 -- what happens when playback_mode is true but unperturbed_trial is false?
+setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.
+if cfg.es == 'ars' then
+    if cfg.weight_rate == nil then cfg.weight_rate = cfg.base_rate end
+else
+    if cfg.weight_rate == nil then cfg.weight_rate = 1.0 end
+end
+if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
+if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end
 setmetatable(cfg, {
     __index = function(t, n)
         if common_cfg[n] ~= nil then return common_cfg[n] end

@@ -276,15 +276,16 @@ local function learn_from_epoch()
     base_params = es:params()
+    -- TODO: move this all to es:decay methods.
     if cfg.es == 'snes' then
         if cfg.sigma_decay > 0 then
             for i, v in ipairs(es.std) do
-                es.std[i] = v * (1 - cfg.learning_rate * cfg.sigma_decay)
+                es.std[i] = v * (1 - es.sigma_rate * cfg.sigma_decay)
             end
         end
         if cfg.weight_decay > 0 then
             for i, v in ipairs(base_params) do
-                base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
+                base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay * es.std[i])
             end
         end
     elseif cfg.es == 'xnes' then
@@ -293,7 +294,7 @@ local function learn_from_epoch()
         end
         if cfg.weight_decay > 0 then
             for i, v in ipairs(base_params) do
-                base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay)
+                base_params[i] = v * (1 - es.weight_rate * cfg.weight_decay)
             end
         end
     else
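
The two hunks above are what the third bullet of the commit message refers to: the decay factors now multiply by es.sigma_rate and es.weight_rate, which the strategy itself may adjust each epoch (see the Snes:adapt hunk further down), instead of the fixed values from the config. A rough numeric sketch of the difference, using the 'snes' preset's sigma_decay and made-up rate values:

    -- illustrative numbers only
    local sigma_decay = 0.01
    local before = 1 - 0.1 * sigma_decay -- old code: always cfg.learning_rate (0.1 in the 'snes' preset)
    local after  = 1 - 0.2 * sigma_decay -- new code: es.sigma_rate, e.g. grown to 0.2 by Snes:adapt
    print(before, after) --> 0.999  0.998
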
@@ -494,14 +495,11 @@ local function init()
         -- if you get an out of memory error, you can't use xNES. sorry!
         -- maybe there'll be a patch for FCEUX in the future.
         es = xnes.Xnes(network.n_param, cfg.epoch_trials,
-            cfg.learning_rate, cfg.deviation, cfg.negate_trials)
-        -- TODO: clean this up into an interface:
-        es.mean_adapt = cfg.mean_adapt
+            cfg.base_rate, cfg.deviation, cfg.negate_trials)
     elseif cfg.es == 'snes' then
         es = snes.Snes(network.n_param, cfg.epoch_trials,
-            cfg.learning_rate, cfg.deviation, cfg.negate_trials)
+            cfg.base_rate, cfg.deviation, cfg.negate_trials)
         -- TODO: clean this up into an interface:
-        es.mean_adapt = cfg.mean_adapt
         es.min_refresh = cfg.min_refresh
if exists(std_fn) then if exists(std_fn) then
@@ -513,11 +511,16 @@ local function init()
         end
     elseif cfg.es == 'ars' then
         es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
-            cfg.learning_rate, cfg.deviation, cfg.negate_trials)
+            cfg.base_rate, cfg.deviation, cfg.negate_trials)
     else
         error("Unknown evolution strategy specified: " + tostring(cfg.es))
     end
+    es.weight_rate = cfg.weight_rate
+    es.sigma_rate = cfg.sigma_rate
+    es.covar_rate = cfg.covar_rate
+    es.rate_init = cfg.sigma_rate -- just for SNES?
     es:params(network:collect())
 end

@@ -32,24 +32,25 @@ local weighted_mann_whitney = util.weighted_mann_whitney
 local Snes = Base:extend()
-function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
+function Snes:init(dims, popsize, base_rate, sigma, antithetic)
     -- heuristic borrowed from CMA-ES:
     self.dims = dims
     self.popsize = popsize or 4 + (3 * floor(log(dims)))
-    self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
+    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
+    self.weight_rate = 1.0
+    self.sigma_rate = base_rate
+    self.covar_rate = base_rate
     self.sigma = sigma or 1
     self.antithetic = antithetic and true or false
     if self.antithetic then self.popsize = self.popsize * 2 end
-    self.rate_init = self.learning_rate
+    self.rate_init = self.sigma_rate
     self.mean = zeros{dims}
     self.std = zeros{dims}
     for i=1, self.dims do self.std[i] = self.sigma end
-    self.mean_adapt = 1.0
     self.old_asked = {}
     self.old_noise = {}
     self.old_score = {}
@@ -250,15 +251,15 @@ function Snes:tell(scored)
     end
     for i, v in ipairs(self.mean) do
-        self.mean[i] = v + self.mean_adapt * step[i]
+        self.mean[i] = v + self.weight_rate * step[i]
     end
     local otherwise = {}
     self.std_old = {}
     for i, v in ipairs(self.std) do
         self.std_old[i] = v
-        self.std[i] = v * exp(self.learning_rate * 0.5 * g_std[i])
-        otherwise[i] = v * exp(self.learning_rate * 0.75 * g_std[i])
+        self.std[i] = v * exp(self.sigma_rate * 0.5 * g_std[i])
+        otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
     end
     self:adapt(asked, otherwise, utility)
@@ -283,11 +284,11 @@ function Snes:adapt(asked, otherwise, qualities)
     --print("p:", p)
     if p < 0.5 - 1 / (3 * (self.dims + 1)) then
-        self.learning_rate = 0.9 * self.learning_rate + 0.1 * self.rate_init
-        print("learning rate -:", self.learning_rate)
+        self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
+        print("learning rate -:", self.sigma_rate)
     else
-        self.learning_rate = min(1.1 * self.learning_rate, 1)
-        print("learning rate +:", self.learning_rate)
+        self.sigma_rate = min(1.1 * self.sigma_rate, 1)
+        print("learning rate +:", self.sigma_rate)
     end
 end
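
Since Snes:adapt now rewrites sigma_rate in place, the value the decay logic in the main script multiplies by is this adapted rate, and rate_init (seeded from cfg.sigma_rate in init()) is what it relaxes back toward. A commented restatement of the branch above, as I read it, with p being the weighted Mann-Whitney statistic computed earlier in adapt():

    if p < 0.5 - 1 / (3 * (self.dims + 1)) then
        -- the more aggressive 'otherwise' std update from Snes:tell showed no
        -- clear advantage: ease the rate back toward its initial value
        self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
    else
        -- it did look better: grow the rate, capped at 1
        self.sigma_rate = min(1.1 * self.sigma_rate, 1)
    end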

@@ -51,11 +51,14 @@ local function make_covars(dims, sigma, out)
     return covars
 end
-function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
+function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
     -- heuristic borrowed from CMA-ES:
     self.dims = dims
     self.popsize = popsize or 4 + (3 * floor(log(dims)))
-    self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
+    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
+    self.weight_rate = 1.0
+    self.sigma_rate = base_rate
+    self.covar_rate = base_rate
     self.sigma = sigma or 1
     self.antithetic = antithetic and true or false
@@ -67,8 +70,6 @@ function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
     -- note: this is technically the co-standard-deviation.
     -- you can imagine the "s" standing for "sqrt" if you like.
     self.covars = make_covars(self.dims, self.sigma, self.covars)
-    self.mean_adapt = 1.0
 end
 function Xnes:params(new_mean)
@@ -182,13 +183,12 @@ function Xnes:tell(scored, noise)
     end
     for i, v in ipairs(self.mean) do
-        self.mean[i] = v + self.mean_adapt * step[i]
+        self.mean[i] = v + self.weight_rate * step[i]
     end
-    local lr = self.learning_rate * 0.5
-    self.sigma = self.sigma * exp(lr * g_sigma)
+    self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
     for i, v in ipairs(self.covars) do
-        self.covars[i] = v * exp(lr * g_covars[i])
+        self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i])
     end
     -- bookkeeping:
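
Finally, a hedged usage sketch of the changed constructors, based only on the signatures visible in this diff (the module path and the numbers are assumptions): base_rate is passed positionally where learning_rate used to go, and the per-kind rates can be overridden on the instance afterwards, which is what init() in the main script now does.

    local xnes = require("xnes") -- assumed module path
    -- (dims, popsize, base_rate, sigma, antithetic); popsize=nil uses the CMA-ES heuristic
    local es = xnes.Xnes(100, nil, 0.01, 1.0, false)
    -- after init, sigma_rate and covar_rate equal base_rate and weight_rate is 1.0;
    -- each can be tweaked independently:
    es.sigma_rate = 0.01
    es.covar_rate = 0.005
    es.weight_rate = 1.0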