restore step logging, remove adamant (for now)
This commit is contained in:
parent
5c64fcf395
commit
b4e49d08b9
4 changed files with 24 additions and 49 deletions
26
ars.lua
26
ars.lua
|
@ -51,30 +51,6 @@ local function kinda_lipschitz(dir, pos, neg, mid)
|
|||
return max(l0, l1) / (2 * dev)
|
||||
end
|
||||
|
||||
-- Rescales `step` using running first/second moment estimates, where the
-- denominator uses the element-wise running maximum of the second moment.
--
-- `step` is overwritten in place (entry by entry); nothing is returned.
--
-- TODO: fix this. The moment buffers (mom1, mom2, mom2max) as well as cfg,
-- epoch_i and nn are not declared inside this function; they resolve to
-- whatever upvalues or globals are in scope at the definition site.
-- NOTE(review): confirm those names are actually visible here.
local function amsgrad(step) -- in-place!
    -- lazily create the moment buffers, each sized to match the step.
    if mom1 == nil then mom1 = nn.zeros(#step) end
    if mom2 == nil then mom2 = nn.zeros(#step) end
    if mom2max == nil then mom2max = nn.zeros(#step) end

    -- decay rates raised to the current epoch, used by the debias branch.
    local b1_t = pow(cfg.adam_b1, epoch_i)
    local b2_t = pow(cfg.adam_b2, epoch_i)

    local beta1, beta2 = cfg.adam_b1, cfg.adam_b2
    local keep1, keep2 = 1 - beta1, 1 - beta2
    local eps = cfg.adam_eps
    local debias = cfg.adam_debias

    -- NOTE: with LuaJIT, splitting this loop would
    --       almost certainly be faster.
    for i, v in ipairs(step) do
        local first = beta1 * mom1[i] + keep1 * v
        local second = beta2 * mom2[i] + keep2 * v * v
        local peak = max(second, mom2max[i])

        mom1[i] = first
        mom2[i] = second
        mom2max[i] = peak

        if debias then
            local num = (first / (1 - b1_t))
            local den = sqrt(peak / (1 - b2_t)) + eps
            step[i] = num / den
        else
            step[i] = first / (sqrt(peak) + eps)
        end
    end
end
|
||||
|
||||
function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
|
@ -201,6 +177,8 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
|
||||
self.noise = nil
|
||||
|
||||
return step
|
||||
end
|
||||
|
||||
return {
|
||||
|
|
29
main.lua
29
main.lua
|
@ -17,9 +17,6 @@ local trial_neg = true
|
|||
local trial_params --= {}
|
||||
local trial_rewards = {}
|
||||
local trials_remaining = 0
|
||||
local mom1 -- first moments in AMSgrad.
|
||||
local mom2 -- second moments in AMSgrad.
|
||||
local mom2max -- running element-wise maximum of mom2.
|
||||
local es -- evolution strategy.
|
||||
|
||||
local trial_frames = 0
|
||||
|
@ -114,10 +111,9 @@ local log_map = {
|
|||
delta_mean = 4,
|
||||
delta_std = 5,
|
||||
step_std = 6,
|
||||
adam_std = 7,
|
||||
weight_mean = 8,
|
||||
weight_std = 9,
|
||||
test_trial = 10,
|
||||
weight_mean = 7,
|
||||
weight_std = 8,
|
||||
test_trial = 9,
|
||||
}
|
||||
|
||||
local function log_csv(t)
|
||||
|
@ -263,28 +259,16 @@ local function learn_from_epoch()
|
|||
end
|
||||
end
|
||||
|
||||
local step
|
||||
if cfg.es == 'ars' and cfg.ars_lips then
|
||||
es:tell(trial_rewards, current_cost)
|
||||
step = es:tell(trial_rewards, current_cost)
|
||||
else
|
||||
es:tell(trial_rewards)
|
||||
step = es:tell(trial_rewards)
|
||||
end
|
||||
|
||||
local step_mean, step_dev = 0, 0
|
||||
--[[ TODO
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
print("step mean:", step_mean)
|
||||
print("step stddev:", step_dev)
|
||||
--]]
|
||||
|
||||
local momstep_mean, momstep_dev = 0, 0
|
||||
--[[ TODO
|
||||
if cfg.adamant then
|
||||
amsgrad(step)
|
||||
momstep_mean, momstep_dev = calc_mean_dev(step)
|
||||
print("amsgrad mean:", momstep_mean)
|
||||
print("amsgrad stddev:", momstep_dev)
|
||||
end
|
||||
--]]
|
||||
|
||||
base_params = es:params()
|
||||
|
||||
|
@ -320,7 +304,6 @@ local function learn_from_epoch()
|
|||
delta_mean = delta_mean,
|
||||
delta_std = delta_std,
|
||||
step_std = step_dev,
|
||||
adam_std = momstep_dev,
|
||||
weight_mean = weight_mean,
|
||||
weight_std = weight_std,
|
||||
test_trial = current_cost or 0,
|
||||
|
|
9
snes.lua
9
snes.lua
|
@ -244,8 +244,13 @@ function Snes:tell(scored)
|
|||
end
|
||||
end
|
||||
|
||||
local step = {}
|
||||
for i, v in ipairs(g_mean) do
|
||||
step[i] = self.std[i] * v
|
||||
end
|
||||
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i]
|
||||
self.mean[i] = v + self.mean_adapt * step[i]
|
||||
end
|
||||
|
||||
local otherwise = {}
|
||||
|
@ -257,6 +262,8 @@ function Snes:tell(scored)
|
|||
end
|
||||
|
||||
self:adapt(asked, otherwise, utility)
|
||||
|
||||
return step
|
||||
end
|
||||
|
||||
function Snes:adapt(asked, otherwise, qualities)
|
||||
|
|
9
xnes.lua
9
xnes.lua
|
@ -179,9 +179,14 @@ function Xnes:tell(scored, noise)
|
|||
|
||||
-- finally, update according to the gradients.
|
||||
|
||||
local step = {}
|
||||
for i, v in ipairs(dotted) do
|
||||
step[i] = self.sigma * v
|
||||
end
|
||||
|
||||
local dotted = dot_mv(self.covars, g_delta)
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v + self.mean_adapt * self.sigma * dotted[i]
|
||||
self.mean[i] = v + self.mean_adapt * step[i]
|
||||
end
|
||||
|
||||
--[[
|
||||
|
@ -201,6 +206,8 @@ function Xnes:tell(scored, noise)
|
|||
--self.sigma = exp(self.log_sigma)
|
||||
--for i, v in ipairs(self.log_covars) do self.covars[i] = exp(v) end
|
||||
self.noise = nil
|
||||
|
||||
return step
|
||||
end
|
||||
|
||||
return {
|
||||
|
|
Loading…
Reference in a new issue