restore step logging, remove adamant (for now)

Connor Olding 2018-06-13 01:36:40 +02:00
parent 5c64fcf395
commit b4e49d08b9
4 changed files with 24 additions and 49 deletions
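In short: each evolution strategy's tell() now returns the raw update step so the training loop can log its statistics again, and the AMSGrad post-processing ("adamant") is shelved for now. The restored caller-side flow, assembled from the hunks below (es, trial_rewards, and calc_mean_dev are names taken from the surrounding code):

    -- sketch of the restored logging flow in learn_from_epoch().
    local step = es:tell(trial_rewards)             -- tell() now returns the step.
    local step_mean, step_dev = calc_mean_dev(step) -- summary stats for the CSV log.
    print("step mean:", step_mean)
    print("step stddev:", step_dev)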

ars.lua

@@ -51,30 +51,6 @@ local function kinda_lipschitz(dir, pos, neg, mid)
     return max(l0, l1) / (2 * dev)
 end
 
-local function amsgrad(step) -- in-place! -- TODO: fix this.
-    if mom1 == nil then mom1 = nn.zeros(#step) end
-    if mom2 == nil then mom2 = nn.zeros(#step) end
-    if mom2max == nil then mom2max = nn.zeros(#step) end
-    local b1_t = pow(cfg.adam_b1, epoch_i)
-    local b2_t = pow(cfg.adam_b2, epoch_i)
-    -- NOTE: with LuaJIT, splitting this loop would
-    -- almost certainly be faster.
-    for i, v in ipairs(step) do
-        mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
-        mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
-        mom2max[i] = max(mom2[i], mom2max[i])
-        if cfg.adam_debias then
-            local num = (mom1[i] / (1 - b1_t))
-            local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
-            step[i] = num / den
-        else
-            step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
-        end
-    end
-end
-
 function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
     self.dims = dims
     self.popsize = popsize or 4 + (3 * floor(log(dims)))
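For reference, the helper removed above is a standard AMSGrad transform: exponential moving averages of the step and its square, with a running element-wise maximum of the second moment standing in for Adam's raw estimate. A self-contained sketch of the same update (omitting the adam_debias branch), with the cfg globals replaced by explicit parameters whose names and defaults are illustrative:

    -- minimal AMSGrad sketch; b1, b2, eps are illustrative defaults.
    local function make_amsgrad(n, b1, b2, eps)
        b1, b2, eps = b1 or 0.9, b2 or 0.999, eps or 1e-8
        local m1, m2, m2max = {}, {}, {}
        for i = 1, n do m1[i], m2[i], m2max[i] = 0, 0, 0 end
        return function(step) -- transforms step in-place, like the removed helper.
            for i, v in ipairs(step) do
                m1[i] = b1 * m1[i] + (1 - b1) * v
                m2[i] = b2 * m2[i] + (1 - b2) * v * v
                m2max[i] = math.max(m2[i], m2max[i])
                step[i] = m1[i] / (math.sqrt(m2max[i]) + eps)
            end
        end
    end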
@@ -201,6 +177,8 @@ function Ars:tell(scored, unperturbed_score)
     end
 
     self.noise = nil
+
+    return step
 end
 
 return {

(filename not captured)

@@ -17,9 +17,6 @@ local trial_neg = true
 local trial_params --= {}
 local trial_rewards = {}
 local trials_remaining = 0
-local mom1 -- first moments in AMSgrad.
-local mom2 -- second moments in AMSgrad.
-local mom2max -- running element-wise maximum of mom2.
 
 local es -- evolution strategy.
 local trial_frames = 0
@@ -114,10 +111,9 @@ local log_map = {
     delta_mean = 4,
     delta_std = 5,
     step_std = 6,
-    adam_std = 7,
-    weight_mean = 8,
-    weight_std = 9,
-    test_trial = 10,
+    weight_mean = 7,
+    weight_std = 8,
+    test_trial = 9,
 }
 
 local function log_csv(t)
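Note that dropping adam_std (old column 7) shifts the three columns after it down by one, so pre- and post-change CSV logs are not column-compatible. A hypothetical shim for indexing old-format rows with the new log_map (the helper and its name are illustrative, not part of the commit):

    -- hypothetical: fetch a named field from a pre-change CSV row.
    -- fields at or above the removed adam_std slot sit one column
    -- later in old files than the new log_map indicates.
    local function read_old_row(row, name)
        local i = log_map[name]
        if i >= 7 then i = i + 1 end
        return row[i]
    end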
@@ -263,28 +259,16 @@ local function learn_from_epoch()
         end
     end
 
+    local step
     if cfg.es == 'ars' and cfg.ars_lips then
-        es:tell(trial_rewards, current_cost)
+        step = es:tell(trial_rewards, current_cost)
     else
-        es:tell(trial_rewards)
+        step = es:tell(trial_rewards)
     end
 
-    local step_mean, step_dev = 0, 0
-    --[[ TODO
     local step_mean, step_dev = calc_mean_dev(step)
     print("step mean:", step_mean)
     print("step stddev:", step_dev)
-    --]]
-
-    local momstep_mean, momstep_dev = 0, 0
-    --[[ TODO
-    if cfg.adamant then
-        amsgrad(step)
-        momstep_mean, momstep_dev = calc_mean_dev(step)
-        print("amsgrad mean:", momstep_mean)
-        print("amsgrad stddev:", momstep_dev)
-    end
-    --]]
 
     base_params = es:params()
@@ -320,7 +304,6 @@ local function learn_from_epoch()
         delta_mean = delta_mean,
         delta_std = delta_std,
         step_std = step_dev,
-        adam_std = momstep_dev,
         weight_mean = weight_mean,
         weight_std = weight_std,
         test_trial = current_cost or 0,

snes.lua

@@ -244,8 +244,13 @@ function Snes:tell(scored)
         end
     end
 
+    local step = {}
+    for i, v in ipairs(g_mean) do
+        step[i] = self.std[i] * v
+    end
+
     for i, v in ipairs(self.mean) do
-        self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i]
+        self.mean[i] = v + self.mean_adapt * step[i]
     end
 
     local otherwise = {}
@@ -257,6 +262,8 @@ function Snes:tell(scored)
     end
 
     self:adapt(asked, otherwise, utility)
+
+    return step
 end
 
 function Snes:adapt(asked, otherwise, qualities)
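The Snes change is behavior-preserving: the mean update used to apply mean_adapt * std[i] * g_mean[i] directly, and now factors std[i] * g_mean[i] out into step first, both so it can be returned for logging and so the update reads the same way as in Ars and Xnes. A quick illustrative check of that equivalence (values are arbitrary):

    -- illustrative equivalence check for the Snes refactor.
    local std, g_mean, mean_adapt = {0.5, 2.0, 1.5}, {0.1, -0.3, 0.25}, 0.5
    for i = 1, #g_mean do
        local step_i = std[i] * g_mean[i]
        local a = mean_adapt * step_i
        local b = mean_adapt * std[i] * g_mean[i]
        assert(math.abs(a - b) < 1e-12)
    end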

xnes.lua

@@ -179,9 +179,14 @@ function Xnes:tell(scored, noise)
     -- finally, update according to the gradients.
     local dotted = dot_mv(self.covars, g_delta)
+    local step = {}
+    for i, v in ipairs(dotted) do
+        step[i] = self.sigma * v
+    end
+
     for i, v in ipairs(self.mean) do
-        self.mean[i] = v + self.mean_adapt * self.sigma * dotted[i]
+        self.mean[i] = v + self.mean_adapt * step[i]
     end
 
     --[[
@@ -201,6 +206,8 @@ function Xnes:tell(scored, noise)
     --self.sigma = exp(self.log_sigma)
     --for i, v in ipairs(self.log_covars) do self.covars[i] = exp(v) end
     self.noise = nil
+
+    return step
 end
 
 return {