add mean adaptation hyperparameter
This commit is contained in:
parent
47eb173dac
commit
d87b8e7118
3 changed files with 9 additions and 6 deletions
|
@ -32,6 +32,9 @@ local common_cfg = {
|
||||||
time_inputs = true, -- binary inputs of global frame count
|
time_inputs = true, -- binary inputs of global frame count
|
||||||
normalize_inputs = false,
|
normalize_inputs = false,
|
||||||
|
|
||||||
|
learning_rate = 1.0,
|
||||||
|
mean_adapt = 1.0, -- for xNES
|
||||||
|
|
||||||
es = 'ars',
|
es = 'ars',
|
||||||
ars_lips = false,
|
ars_lips = false,
|
||||||
adamant = false, -- run steps through AMSgrad.
|
adamant = false, -- run steps through AMSgrad.
|
||||||
|
|
2
main.lua
2
main.lua
|
@ -456,6 +456,8 @@ local function init()
|
||||||
-- maybe there'll be a patch for FCEUX in the future.
|
-- maybe there'll be a patch for FCEUX in the future.
|
||||||
es = xnes.Xnes(network.n_param, cfg.epoch_trials,
|
es = xnes.Xnes(network.n_param, cfg.epoch_trials,
|
||||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||||
|
-- TODO: clean this up into an interface:
|
||||||
|
es.mean_adapt = cfg.mean_adapt
|
||||||
elseif cfg.es == 'ars' then
|
elseif cfg.es == 'ars' then
|
||||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||||
|
|
10
xnes.lua
10
xnes.lua
|
@ -71,17 +71,15 @@ function Xnes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||||
--self.log_sigma = log(self.sigma)
|
--self.log_sigma = log(self.sigma)
|
||||||
--self.log_covars = zeros{dims, dims}
|
--self.log_covars = zeros{dims, dims}
|
||||||
--for i, v in ipairs(self.covars) do self.log_covars[i] = log(v) end
|
--for i, v in ipairs(self.covars) do self.log_covars[i] = log(v) end
|
||||||
|
|
||||||
|
self.mean_adapt = 1.0
|
||||||
end
|
end
|
||||||
|
|
||||||
function Xnes:params(new_mean, new_covars)
|
function Xnes:params(new_mean)
|
||||||
if new_mean ~= nil then
|
if new_mean ~= nil then
|
||||||
assert(#self.mean == #new_mean, "new parameters have the wrong size")
|
assert(#self.mean == #new_mean, "new parameters have the wrong size")
|
||||||
for i, v in ipairs(new_mean) do self.mean[i] = v end
|
for i, v in ipairs(new_mean) do self.mean[i] = v end
|
||||||
end
|
end
|
||||||
if new_covars ~= nil then
|
|
||||||
-- TODO: assert determinant of new_covars is 1.
|
|
||||||
error("TODO")
|
|
||||||
end
|
|
||||||
return self.mean
|
return self.mean
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -183,7 +181,7 @@ function Xnes:tell(scored, noise)
|
||||||
|
|
||||||
local dotted = dot_mv(self.covars, g_delta)
|
local dotted = dot_mv(self.covars, g_delta)
|
||||||
for i, v in ipairs(self.mean) do
|
for i, v in ipairs(self.mean) do
|
||||||
self.mean[i] = v + self.sigma * dotted[i]
|
self.mean[i] = v + self.mean_adapt * self.sigma * dotted[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
--[[
|
--[[
|
||||||
|
|
Loading…
Reference in a new issue