add AMSgrad optimizer and logging

This commit is contained in:
Connor Olding 2018-05-03 15:33:17 +02:00
parent c7c657513e
commit 90922a2bc3
2 changed files with 122 additions and 19 deletions

View File

@@ -2,7 +2,19 @@ local function approx_cossim(dim)
return math.pow(1.521 * dim - 0.521, -0.5026)
end
local function intmap(x)
	-- Convert a half-decade exponent into a magnitude: each unit
	-- decrease in x divides the result by sqrt(10).
	-- e.g.  0 -> 1.0,  -1 -> 0.316,  -2 -> 0.1,  -3 -> 0.0316,  -4 -> 0.01, ...
	local half_decades = x * 0.5
	return 10 ^ half_decades
end
local cfg = {
log_fn = 'log.csv', -- can be nil to disable logging.
defer_prints = true,
playable_mode = false,
@@ -16,20 +28,23 @@ local cfg = {
det_epsilon = false, -- take random actions with probability eps.
graycode = false,
epoch_trials = 5,
epoch_top_trials = 2, -- new with ARS.
unperturbed_trial = false, -- do a trial without any noise.
epoch_trials = 64 * (7/8),
epoch_top_trials = 40 * (7/8), -- new with ARS.
unperturbed_trial = true, -- do a trial without any noise.
negate_trials = true, -- try pairs of normal and negated noise directions.
time_inputs = true, -- binary inputs of global frame count
-- ^ note that this now doubles the effective trials.
deviation = 0.32,
--learning_rate = 0.01 / approx_cossim(7051)
learning_rate = 0.32,
--learning_rate = 0.0032 / approx_cossim(66573)
--learning_rate = 0.0056 / approx_cossim(66573)
weight_decay = 0.0032,
deviation = intmap(-3),
learning_rate = intmap(-4),
weight_decay = intmap(-6),
cap_time = 200, --400
adamant = true, -- run steps through AMSgrad.
adam_b1 = math.pow(10, -1 / 15),
adam_b2 = math.pow(10, -1 / 100),
adam_eps = intmap(-8),
adam_debias = false,
cap_time = 222, --400
timer_loser = 1/2,
decrement_reward = false, -- bad idea, encourages mario to kill himself

106
main.lua
View File

@@ -16,6 +16,9 @@ local trial_neg = true
local trial_noise = {}
local trial_rewards = {}
local trials_remaining = 0
local mom1 -- first moments in AMSgrad.
local mom2 -- second moments in AMSgrad.
local mom2max -- running element-wise maximum of mom2.
local trial_frames = 0
local total_frames = 0
@@ -55,6 +58,7 @@ local print = print
local ipairs = ipairs
local pairs = pairs
local select = select
local open = io.open
local abs = math.abs
local floor = math.floor
local ceil = math.ceil
@@ -88,6 +92,39 @@ local gui = gui
-- utilities.
-- Column order for the CSV log: maps each log key to its 1-based
-- column index. log_csv uses this both to order the row's values and
-- to verify that every expected key was supplied (and no extras).
local log_map = {
epoch = 1,
trial_mean = 2,
trial_std = 3,
delta_mean = 4,
delta_std = 5,
step_std = 6,
adam_std = 7,
weight_mean = 8,
weight_std = 9,
}
-- Append one row to the CSV log file named by cfg.log_fn.
-- t: table keyed by the names in log_map; every key must be present
-- and no others. Raises on unknown/missing keys or open failure.
-- No-op when logging is disabled (cfg.log_fn == nil).
local function log_csv(t)
	if cfg.log_fn == nil then return end
	-- Validate and order the values BEFORE opening the file, so a
	-- key error cannot leak an open file handle.
	local values = {}
	for k, v in pairs(t) do
		local i = log_map[k]
		if i == nil then error("Unexpected log key "..tostring(k)) end
		values[i] = v
	end
	for k, i in pairs(log_map) do
		if values[i] == nil then error("Missing log key "..tostring(k)) end
	end
	for i = 1, #values do values[i] = tostring(values[i]) end
	local f = open(cfg.log_fn, 'a')
	if f == nil then error("Failed to open log file "..cfg.log_fn) end
	f:write(table.concat(values, ","), '\n')
	f:close()
end
local function boolean_xor(a, b)
if a and b then return false end
if not a and not b then return false end
@@ -444,7 +481,6 @@ local function prepare_epoch()
local noise = nn.zeros(#base_params)
-- NOTE: change in implementation: deviation is multiplied here
-- and ONLY here now.
--if i % 2 == 0 then -- FIXME: just messing around.
if cfg.graycode then
--local precision = 1 / cfg.deviation
--print(cfg.deviation, precision)
@@ -569,10 +605,17 @@ local function learn_from_epoch()
local step = nn.zeros(#base_params)
-- new stuff
local delta_rewards -- only used for logging.
local best_rewards
if cfg.negate_trials then
delta_rewards = {}
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = trial_rewards[ind + 0]
local neg = trial_rewards[ind + 1]
delta_rewards[i] = abs(pos - neg)
end
-- select one (the best) reward of each pos/neg pair.
best_rewards = {}
for i = 1, cfg.epoch_trials do
@@ -602,12 +645,12 @@ local function learn_from_epoch()
end
--print("top:", top_rewards)
local delta_rewards = {} -- only used for printing.
local top_delta_rewards = {} -- only used for printing.
for i, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
end
print("best deltas:", delta_rewards)
print("best deltas:", top_delta_rewards)
local _, reward_dev = calc_mean_dev(top_rewards)
--print("mean, dev:", _, reward_dev)
@@ -628,14 +671,59 @@ local function learn_from_epoch()
end
local step_mean, step_dev = calc_mean_dev(step)
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
--print("step stddev:", step_dev)
print("full step stddev:", cfg.learning_rate * step_dev)
print("step mean:", step_mean)
print("step stddev:", step_dev)
--print("full step stddev:", cfg.learning_rate * step_dev)
local momstep_mean, momstep_dev
if cfg.adamant then
if mom1 == nil then mom1 = nn.zeros(#step) end
if mom2 == nil then mom2 = nn.zeros(#step) end
if mom2max == nil then mom2max = nn.zeros(#step) end
local b1_t = pow(cfg.adam_b1, epoch_i)
local b2_t = pow(cfg.adam_b2, epoch_i)
-- NOTE: with LuaJIT, splitting this loop would
-- almost certainly be faster.
for i, v in ipairs(step) do
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
mom2max[i] = max(mom2[i], mom2max[i])
if cfg.adam_debias then
local num = (mom1[i] / (1 - b1_t))
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
step[i] = num / den
else
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
end
end
momstep_mean, momstep_dev = calc_mean_dev(step)
print("amsgrad mean:", momstep_mean)
print("amsgrad stddev:", momstep_dev)
end
for i, v in ipairs(base_params) do
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
end
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
local weight_mean, weight_std = calc_mean_dev(base_params)
log_csv{
epoch = epoch_i,
trial_mean = trial_mean,
trial_std = trial_std,
delta_mean = delta_mean,
delta_std = delta_std,
step_std = step_dev,
adam_std = momstep_dev,
weight_mean = weight_mean,
weight_std = weight_std,
}
if cfg.enable_network then
network:distribute(base_params)
network:save()