add AMSgrad optimizer and logging
This commit is contained in:
parent
c7c657513e
commit
90922a2bc3
2 changed files with 122 additions and 19 deletions
35
config.lua
35
config.lua
|
@ -2,7 +2,19 @@ local function approx_cossim(dim)
|
|||
return math.pow(1.521 * dim - 0.521, -0.5026)
|
||||
end
|
||||
|
||||
local function intmap(x)
|
||||
-- 0 -> 1.0
|
||||
-- -1 -> 0.316
|
||||
-- -2 -> 0.1
|
||||
-- -3 -> 0.0316
|
||||
-- -4 -> 0.01
|
||||
-- etc.
|
||||
return math.pow(10, x / 2)
|
||||
end
|
||||
|
||||
local cfg = {
|
||||
log_fn = 'log.csv', -- can be nil to disable logging.
|
||||
|
||||
defer_prints = true,
|
||||
|
||||
playable_mode = false,
|
||||
|
@ -16,20 +28,23 @@ local cfg = {
|
|||
det_epsilon = false, -- take random actions with probability eps.
|
||||
|
||||
graycode = false,
|
||||
epoch_trials = 5,
|
||||
epoch_top_trials = 2, -- new with ARS.
|
||||
unperturbed_trial = false, -- do a trial without any noise.
|
||||
epoch_trials = 64 * (7/8),
|
||||
epoch_top_trials = 40 * (7/8), -- new with ARS.
|
||||
unperturbed_trial = true, -- do a trial without any noise.
|
||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||
time_inputs = true, -- binary inputs of global frame count
|
||||
-- ^ note that this now doubles the effective trials.
|
||||
deviation = 0.32,
|
||||
--learning_rate = 0.01 / approx_cossim(7051)
|
||||
learning_rate = 0.32,
|
||||
--learning_rate = 0.0032 / approx_cossim(66573)
|
||||
--learning_rate = 0.0056 / approx_cossim(66573)
|
||||
weight_decay = 0.0032,
|
||||
deviation = intmap(-3),
|
||||
learning_rate = intmap(-4),
|
||||
weight_decay = intmap(-6),
|
||||
|
||||
cap_time = 200, --400
|
||||
adamant = true, -- run steps through AMSgrad.
|
||||
adam_b1 = math.pow(10, -1 / 15),
|
||||
adam_b2 = math.pow(10, -1 / 100),
|
||||
adam_eps = intmap(-8),
|
||||
adam_debias = false,
|
||||
|
||||
cap_time = 222, --400
|
||||
timer_loser = 1/2,
|
||||
decrement_reward = false, -- bad idea, encourages mario to kill himself
|
||||
|
||||
|
|
106
main.lua
106
main.lua
|
@ -16,6 +16,9 @@ local trial_neg = true
|
|||
local trial_noise = {}
|
||||
local trial_rewards = {}
|
||||
local trials_remaining = 0
|
||||
local mom1 -- first moments in AMSgrad.
|
||||
local mom2 -- second moments in AMSgrad.
|
||||
local mom2max -- running element-wise maximum of mom2.
|
||||
|
||||
local trial_frames = 0
|
||||
local total_frames = 0
|
||||
|
@ -55,6 +58,7 @@ local print = print
|
|||
local ipairs = ipairs
|
||||
local pairs = pairs
|
||||
local select = select
|
||||
local open = io.open
|
||||
local abs = math.abs
|
||||
local floor = math.floor
|
||||
local ceil = math.ceil
|
||||
|
@ -88,6 +92,39 @@ local gui = gui
|
|||
|
||||
-- utilities.
|
||||
|
||||
local log_map = {
|
||||
epoch = 1,
|
||||
trial_mean = 2,
|
||||
trial_std = 3,
|
||||
delta_mean = 4,
|
||||
delta_std = 5,
|
||||
step_std = 6,
|
||||
adam_std = 7,
|
||||
weight_mean = 8,
|
||||
weight_std = 9,
|
||||
}
|
||||
|
||||
local function log_csv(t)
|
||||
if cfg.log_fn == nil then return end
|
||||
local f = open(cfg.log_fn, 'a')
|
||||
if f == nil then error("Failed to open log file "..cfg.log_fn) end
|
||||
local values = {}
|
||||
for k, v in pairs(t) do
|
||||
local i = log_map[k]
|
||||
if i == nil then error("Unexpected log key "..tostring(k)) end
|
||||
values[i] = v
|
||||
end
|
||||
for k, i in pairs(log_map) do
|
||||
if values[i] == nil then error("Missing log key "..tostring(k)) end
|
||||
end
|
||||
for i, v in ipairs(values) do
|
||||
f:write(tostring(v))
|
||||
if i ~= #values then f:write(",") end
|
||||
end
|
||||
f:write('\n')
|
||||
f:close()
|
||||
end
|
||||
|
||||
local function boolean_xor(a, b)
|
||||
if a and b then return false end
|
||||
if not a and not b then return false end
|
||||
|
@ -444,7 +481,6 @@ local function prepare_epoch()
|
|||
local noise = nn.zeros(#base_params)
|
||||
-- NOTE: change in implementation: deviation is multiplied here
|
||||
-- and ONLY here now.
|
||||
--if i % 2 == 0 then -- FIXME: just messing around.
|
||||
if cfg.graycode then
|
||||
--local precision = 1 / cfg.deviation
|
||||
--print(cfg.deviation, precision)
|
||||
|
@ -569,10 +605,17 @@ local function learn_from_epoch()
|
|||
|
||||
local step = nn.zeros(#base_params)
|
||||
|
||||
-- new stuff
|
||||
|
||||
local delta_rewards -- only used for logging.
|
||||
local best_rewards
|
||||
if cfg.negate_trials then
|
||||
delta_rewards = {}
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = trial_rewards[ind + 0]
|
||||
local neg = trial_rewards[ind + 1]
|
||||
delta_rewards[i] = abs(pos - neg)
|
||||
end
|
||||
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
best_rewards = {}
|
||||
for i = 1, cfg.epoch_trials do
|
||||
|
@ -602,12 +645,12 @@ local function learn_from_epoch()
|
|||
end
|
||||
--print("top:", top_rewards)
|
||||
|
||||
local delta_rewards = {} -- only used for printing.
|
||||
local top_delta_rewards = {} -- only used for printing.
|
||||
for i, ind in ipairs(indices) do
|
||||
local sind = (ind - 1) * 2 + 1
|
||||
delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
||||
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
||||
end
|
||||
print("best deltas:", delta_rewards)
|
||||
print("best deltas:", top_delta_rewards)
|
||||
|
||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||
--print("mean, dev:", _, reward_dev)
|
||||
|
@ -628,14 +671,59 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||
--print("step stddev:", step_dev)
|
||||
print("full step stddev:", cfg.learning_rate * step_dev)
|
||||
print("step mean:", step_mean)
|
||||
print("step stddev:", step_dev)
|
||||
--print("full step stddev:", cfg.learning_rate * step_dev)
|
||||
|
||||
local momstep_mean, momstep_dev
|
||||
if cfg.adamant then
|
||||
if mom1 == nil then mom1 = nn.zeros(#step) end
|
||||
if mom2 == nil then mom2 = nn.zeros(#step) end
|
||||
if mom2max == nil then mom2max = nn.zeros(#step) end
|
||||
|
||||
local b1_t = pow(cfg.adam_b1, epoch_i)
|
||||
local b2_t = pow(cfg.adam_b2, epoch_i)
|
||||
|
||||
-- NOTE: with LuaJIT, splitting this loop would
|
||||
-- almost certainly be faster.
|
||||
for i, v in ipairs(step) do
|
||||
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
|
||||
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
|
||||
mom2max[i] = max(mom2[i], mom2max[i])
|
||||
if cfg.adam_debias then
|
||||
local num = (mom1[i] / (1 - b1_t))
|
||||
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
|
||||
step[i] = num / den
|
||||
else
|
||||
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
|
||||
end
|
||||
end
|
||||
|
||||
momstep_mean, momstep_dev = calc_mean_dev(step)
|
||||
print("amsgrad mean:", momstep_mean)
|
||||
print("amsgrad stddev:", momstep_dev)
|
||||
end
|
||||
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
|
||||
end
|
||||
|
||||
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
||||
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
||||
local weight_mean, weight_std = calc_mean_dev(base_params)
|
||||
|
||||
log_csv{
|
||||
epoch = epoch_i,
|
||||
trial_mean = trial_mean,
|
||||
trial_std = trial_std,
|
||||
delta_mean = delta_mean,
|
||||
delta_std = delta_std,
|
||||
step_std = step_dev,
|
||||
adam_std = momstep_dev,
|
||||
weight_mean = weight_mean,
|
||||
weight_std = weight_std,
|
||||
}
|
||||
|
||||
if cfg.enable_network then
|
||||
network:distribute(base_params)
|
||||
network:save()
|
||||
|
|
Loading…
Reference in a new issue