add AMSgrad optimizer and logging
This commit is contained in:
parent
c7c657513e
commit
90922a2bc3
2 changed files with 122 additions and 19 deletions
35
config.lua
35
config.lua
|
@ -2,7 +2,19 @@ local function approx_cossim(dim)
|
||||||
return math.pow(1.521 * dim - 0.521, -0.5026)
|
return math.pow(1.521 * dim - 0.521, -0.5026)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function intmap(x)
|
||||||
|
-- 0 -> 1.0
|
||||||
|
-- -1 -> 0.316
|
||||||
|
-- -2 -> 0.1
|
||||||
|
-- -3 -> 0.0316
|
||||||
|
-- -4 -> 0.01
|
||||||
|
-- etc.
|
||||||
|
return math.pow(10, x / 2)
|
||||||
|
end
|
||||||
|
|
||||||
local cfg = {
|
local cfg = {
|
||||||
|
log_fn = 'log.csv', -- can be nil to disable logging.
|
||||||
|
|
||||||
defer_prints = true,
|
defer_prints = true,
|
||||||
|
|
||||||
playable_mode = false,
|
playable_mode = false,
|
||||||
|
@ -16,20 +28,23 @@ local cfg = {
|
||||||
det_epsilon = false, -- take random actions with probability eps.
|
det_epsilon = false, -- take random actions with probability eps.
|
||||||
|
|
||||||
graycode = false,
|
graycode = false,
|
||||||
epoch_trials = 5,
|
epoch_trials = 64 * (7/8),
|
||||||
epoch_top_trials = 2, -- new with ARS.
|
epoch_top_trials = 40 * (7/8), -- new with ARS.
|
||||||
unperturbed_trial = false, -- do a trial without any noise.
|
unperturbed_trial = true, -- do a trial without any noise.
|
||||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||||
time_inputs = true, -- binary inputs of global frame count
|
time_inputs = true, -- binary inputs of global frame count
|
||||||
-- ^ note that this now doubles the effective trials.
|
-- ^ note that this now doubles the effective trials.
|
||||||
deviation = 0.32,
|
deviation = intmap(-3),
|
||||||
--learning_rate = 0.01 / approx_cossim(7051)
|
learning_rate = intmap(-4),
|
||||||
learning_rate = 0.32,
|
weight_decay = intmap(-6),
|
||||||
--learning_rate = 0.0032 / approx_cossim(66573)
|
|
||||||
--learning_rate = 0.0056 / approx_cossim(66573)
|
|
||||||
weight_decay = 0.0032,
|
|
||||||
|
|
||||||
cap_time = 200, --400
|
adamant = true, -- run steps through AMSgrad.
|
||||||
|
adam_b1 = math.pow(10, -1 / 15),
|
||||||
|
adam_b2 = math.pow(10, -1 / 100),
|
||||||
|
adam_eps = intmap(-8),
|
||||||
|
adam_debias = false,
|
||||||
|
|
||||||
|
cap_time = 222, --400
|
||||||
timer_loser = 1/2,
|
timer_loser = 1/2,
|
||||||
decrement_reward = false, -- bad idea, encourages mario to kill himself
|
decrement_reward = false, -- bad idea, encourages mario to kill himself
|
||||||
|
|
||||||
|
|
106
main.lua
106
main.lua
|
@ -16,6 +16,9 @@ local trial_neg = true
|
||||||
local trial_noise = {}
|
local trial_noise = {}
|
||||||
local trial_rewards = {}
|
local trial_rewards = {}
|
||||||
local trials_remaining = 0
|
local trials_remaining = 0
|
||||||
|
local mom1 -- first moments in AMSgrad.
|
||||||
|
local mom2 -- second moments in AMSgrad.
|
||||||
|
local mom2max -- running element-wise maximum of mom2.
|
||||||
|
|
||||||
local trial_frames = 0
|
local trial_frames = 0
|
||||||
local total_frames = 0
|
local total_frames = 0
|
||||||
|
@ -55,6 +58,7 @@ local print = print
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local pairs = pairs
|
local pairs = pairs
|
||||||
local select = select
|
local select = select
|
||||||
|
local open = io.open
|
||||||
local abs = math.abs
|
local abs = math.abs
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
local ceil = math.ceil
|
local ceil = math.ceil
|
||||||
|
@ -88,6 +92,39 @@ local gui = gui
|
||||||
|
|
||||||
-- utilities.
|
-- utilities.
|
||||||
|
|
||||||
|
local log_map = {
|
||||||
|
epoch = 1,
|
||||||
|
trial_mean = 2,
|
||||||
|
trial_std = 3,
|
||||||
|
delta_mean = 4,
|
||||||
|
delta_std = 5,
|
||||||
|
step_std = 6,
|
||||||
|
adam_std = 7,
|
||||||
|
weight_mean = 8,
|
||||||
|
weight_std = 9,
|
||||||
|
}
|
||||||
|
|
||||||
|
local function log_csv(t)
|
||||||
|
if cfg.log_fn == nil then return end
|
||||||
|
local f = open(cfg.log_fn, 'a')
|
||||||
|
if f == nil then error("Failed to open log file "..cfg.log_fn) end
|
||||||
|
local values = {}
|
||||||
|
for k, v in pairs(t) do
|
||||||
|
local i = log_map[k]
|
||||||
|
if i == nil then error("Unexpected log key "..tostring(k)) end
|
||||||
|
values[i] = v
|
||||||
|
end
|
||||||
|
for k, i in pairs(log_map) do
|
||||||
|
if values[i] == nil then error("Missing log key "..tostring(k)) end
|
||||||
|
end
|
||||||
|
for i, v in ipairs(values) do
|
||||||
|
f:write(tostring(v))
|
||||||
|
if i ~= #values then f:write(",") end
|
||||||
|
end
|
||||||
|
f:write('\n')
|
||||||
|
f:close()
|
||||||
|
end
|
||||||
|
|
||||||
local function boolean_xor(a, b)
|
local function boolean_xor(a, b)
|
||||||
if a and b then return false end
|
if a and b then return false end
|
||||||
if not a and not b then return false end
|
if not a and not b then return false end
|
||||||
|
@ -444,7 +481,6 @@ local function prepare_epoch()
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
-- NOTE: change in implementation: deviation is multiplied here
|
-- NOTE: change in implementation: deviation is multiplied here
|
||||||
-- and ONLY here now.
|
-- and ONLY here now.
|
||||||
--if i % 2 == 0 then -- FIXME: just messing around.
|
|
||||||
if cfg.graycode then
|
if cfg.graycode then
|
||||||
--local precision = 1 / cfg.deviation
|
--local precision = 1 / cfg.deviation
|
||||||
--print(cfg.deviation, precision)
|
--print(cfg.deviation, precision)
|
||||||
|
@ -569,10 +605,17 @@ local function learn_from_epoch()
|
||||||
|
|
||||||
local step = nn.zeros(#base_params)
|
local step = nn.zeros(#base_params)
|
||||||
|
|
||||||
-- new stuff
|
local delta_rewards -- only used for logging.
|
||||||
|
|
||||||
local best_rewards
|
local best_rewards
|
||||||
if cfg.negate_trials then
|
if cfg.negate_trials then
|
||||||
|
delta_rewards = {}
|
||||||
|
for i = 1, cfg.epoch_trials do
|
||||||
|
local ind = (i - 1) * 2 + 1
|
||||||
|
local pos = trial_rewards[ind + 0]
|
||||||
|
local neg = trial_rewards[ind + 1]
|
||||||
|
delta_rewards[i] = abs(pos - neg)
|
||||||
|
end
|
||||||
|
|
||||||
-- select one (the best) reward of each pos/neg pair.
|
-- select one (the best) reward of each pos/neg pair.
|
||||||
best_rewards = {}
|
best_rewards = {}
|
||||||
for i = 1, cfg.epoch_trials do
|
for i = 1, cfg.epoch_trials do
|
||||||
|
@ -602,12 +645,12 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
--print("top:", top_rewards)
|
--print("top:", top_rewards)
|
||||||
|
|
||||||
local delta_rewards = {} -- only used for printing.
|
local top_delta_rewards = {} -- only used for printing.
|
||||||
for i, ind in ipairs(indices) do
|
for i, ind in ipairs(indices) do
|
||||||
local sind = (ind - 1) * 2 + 1
|
local sind = (ind - 1) * 2 + 1
|
||||||
delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
||||||
end
|
end
|
||||||
print("best deltas:", delta_rewards)
|
print("best deltas:", top_delta_rewards)
|
||||||
|
|
||||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||||
--print("mean, dev:", _, reward_dev)
|
--print("mean, dev:", _, reward_dev)
|
||||||
|
@ -628,14 +671,59 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
print("step mean:", step_mean)
|
||||||
--print("step stddev:", step_dev)
|
print("step stddev:", step_dev)
|
||||||
print("full step stddev:", cfg.learning_rate * step_dev)
|
--print("full step stddev:", cfg.learning_rate * step_dev)
|
||||||
|
|
||||||
|
local momstep_mean, momstep_dev
|
||||||
|
if cfg.adamant then
|
||||||
|
if mom1 == nil then mom1 = nn.zeros(#step) end
|
||||||
|
if mom2 == nil then mom2 = nn.zeros(#step) end
|
||||||
|
if mom2max == nil then mom2max = nn.zeros(#step) end
|
||||||
|
|
||||||
|
local b1_t = pow(cfg.adam_b1, epoch_i)
|
||||||
|
local b2_t = pow(cfg.adam_b2, epoch_i)
|
||||||
|
|
||||||
|
-- NOTE: with LuaJIT, splitting this loop would
|
||||||
|
-- almost certainly be faster.
|
||||||
|
for i, v in ipairs(step) do
|
||||||
|
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
|
||||||
|
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
|
||||||
|
mom2max[i] = max(mom2[i], mom2max[i])
|
||||||
|
if cfg.adam_debias then
|
||||||
|
local num = (mom1[i] / (1 - b1_t))
|
||||||
|
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
|
||||||
|
step[i] = num / den
|
||||||
|
else
|
||||||
|
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
momstep_mean, momstep_dev = calc_mean_dev(step)
|
||||||
|
print("amsgrad mean:", momstep_mean)
|
||||||
|
print("amsgrad stddev:", momstep_dev)
|
||||||
|
end
|
||||||
|
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
|
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local trial_mean, trial_std = calc_mean_dev(trial_rewards)
|
||||||
|
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
|
||||||
|
local weight_mean, weight_std = calc_mean_dev(base_params)
|
||||||
|
|
||||||
|
log_csv{
|
||||||
|
epoch = epoch_i,
|
||||||
|
trial_mean = trial_mean,
|
||||||
|
trial_std = trial_std,
|
||||||
|
delta_mean = delta_mean,
|
||||||
|
delta_std = delta_std,
|
||||||
|
step_std = step_dev,
|
||||||
|
adam_std = momstep_dev,
|
||||||
|
weight_mean = weight_mean,
|
||||||
|
weight_std = weight_std,
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.enable_network then
|
if cfg.enable_network then
|
||||||
network:distribute(base_params)
|
network:distribute(base_params)
|
||||||
network:save()
|
network:save()
|
||||||
|
|
Loading…
Reference in a new issue