diff --git a/config.lua b/config.lua index c7ce9b3..5821ef7 100644 --- a/config.lua +++ b/config.lua @@ -2,7 +2,19 @@ local function approx_cossim(dim) return math.pow(1.521 * dim - 0.521, -0.5026) end +local function intmap(x) + -- 0 -> 1.0 + -- -1 -> 0.316 + -- -2 -> 0.1 + -- -3 -> 0.0316 + -- -4 -> 0.01 + -- etc. + return math.pow(10, x / 2) +end + local cfg = { + log_fn = 'log.csv', -- can be nil to disable logging. + defer_prints = true, playable_mode = false, @@ -16,20 +28,23 @@ local cfg = { det_epsilon = false, -- take random actions with probability eps. graycode = false, - epoch_trials = 5, - epoch_top_trials = 2, -- new with ARS. - unperturbed_trial = false, -- do a trial without any noise. + epoch_trials = 64 * (7/8), + epoch_top_trials = 40 * (7/8), -- new with ARS. + unperturbed_trial = true, -- do a trial without any noise. negate_trials = true, -- try pairs of normal and negated noise directions. time_inputs = true, -- binary inputs of global frame count -- ^ note that this now doubles the effective trials. - deviation = 0.32, - --learning_rate = 0.01 / approx_cossim(7051) - learning_rate = 0.32, - --learning_rate = 0.0032 / approx_cossim(66573) - --learning_rate = 0.0056 / approx_cossim(66573) - weight_decay = 0.0032, + deviation = intmap(-3), + learning_rate = intmap(-4), + weight_decay = intmap(-6), - cap_time = 200, --400 + adamant = true, -- run steps through AMSgrad. + adam_b1 = math.pow(10, -1 / 15), + adam_b2 = math.pow(10, -1 / 100), + adam_eps = intmap(-8), + adam_debias = false, + + cap_time = 222, --400 timer_loser = 1/2, decrement_reward = false, -- bad idea, encourages mario to kill himself diff --git a/main.lua b/main.lua index 0c71453..d69b734 100644 --- a/main.lua +++ b/main.lua @@ -16,6 +16,9 @@ local trial_neg = true local trial_noise = {} local trial_rewards = {} local trials_remaining = 0 +local mom1 -- first moments in AMSgrad. +local mom2 -- second moments in AMSgrad. +local mom2max -- running element-wise maximum of mom2. local trial_frames = 0 local total_frames = 0 @@ -55,6 +58,7 @@ local print = print local ipairs = ipairs local pairs = pairs local select = select +local open = io.open local abs = math.abs local floor = math.floor local ceil = math.ceil @@ -88,6 +92,39 @@ local gui = gui -- utilities. +local log_map = { + epoch = 1, + trial_mean = 2, + trial_std = 3, + delta_mean = 4, + delta_std = 5, + step_std = 6, + adam_std = 7, + weight_mean = 8, + weight_std = 9, +} + +local function log_csv(t) + if cfg.log_fn == nil then return end + local f = open(cfg.log_fn, 'a') + if f == nil then error("Failed to open log file "..cfg.log_fn) end + local values = {} + for k, v in pairs(t) do + local i = log_map[k] + if i == nil then error("Unexpected log key "..tostring(k)) end + values[i] = v + end + for k, i in pairs(log_map) do + if values[i] == nil then error("Missing log key "..tostring(k)) end + end + for i, v in ipairs(values) do + f:write(tostring(v)) + if i ~= #values then f:write(",") end + end + f:write('\n') + f:close() +end + local function boolean_xor(a, b) if a and b then return false end if not a and not b then return false end @@ -444,7 +481,6 @@ local function prepare_epoch() local noise = nn.zeros(#base_params) -- NOTE: change in implementation: deviation is multiplied here -- and ONLY here now. - --if i % 2 == 0 then -- FIXME: just messing around. if cfg.graycode then --local precision = 1 / cfg.deviation --print(cfg.deviation, precision) @@ -569,10 +605,17 @@ local function learn_from_epoch() local step = nn.zeros(#base_params) - -- new stuff - + local delta_rewards -- only used for logging. local best_rewards if cfg.negate_trials then + delta_rewards = {} + for i = 1, cfg.epoch_trials do + local ind = (i - 1) * 2 + 1 + local pos = trial_rewards[ind + 0] + local neg = trial_rewards[ind + 1] + delta_rewards[i] = abs(pos - neg) + end + -- select one (the best) reward of each pos/neg pair. best_rewards = {} for i = 1, cfg.epoch_trials do @@ -602,12 +645,12 @@ local function learn_from_epoch() end --print("top:", top_rewards) - local delta_rewards = {} -- only used for printing. + local top_delta_rewards = {} -- only used for printing. for i, ind in ipairs(indices) do local sind = (ind - 1) * 2 + 1 - delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) + top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) end - print("best deltas:", delta_rewards) + print("best deltas:", top_delta_rewards) local _, reward_dev = calc_mean_dev(top_rewards) --print("mean, dev:", _, reward_dev) @@ -628,14 +671,59 @@ local function learn_from_epoch() end local step_mean, step_dev = calc_mean_dev(step) - if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end - --print("step stddev:", step_dev) - print("full step stddev:", cfg.learning_rate * step_dev) + print("step mean:", step_mean) + print("step stddev:", step_dev) + --print("full step stddev:", cfg.learning_rate * step_dev) + + local momstep_mean, momstep_dev + if cfg.adamant then + if mom1 == nil then mom1 = nn.zeros(#step) end + if mom2 == nil then mom2 = nn.zeros(#step) end + if mom2max == nil then mom2max = nn.zeros(#step) end + + local b1_t = pow(cfg.adam_b1, epoch_i) + local b2_t = pow(cfg.adam_b2, epoch_i) + + -- NOTE: with LuaJIT, splitting this loop would + -- almost certainly be faster. + for i, v in ipairs(step) do + mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v + mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v + mom2max[i] = max(mom2[i], mom2max[i]) + if cfg.adam_debias then + local num = (mom1[i] / (1 - b1_t)) + local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps + step[i] = num / den + else + step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps) + end + end + + momstep_mean, momstep_dev = calc_mean_dev(step) + print("amsgrad mean:", momstep_mean) + print("amsgrad stddev:", momstep_dev) + end for i, v in ipairs(base_params) do base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v end + local trial_mean, trial_std = calc_mean_dev(trial_rewards) + local delta_mean, delta_std = calc_mean_dev(delta_rewards) + local weight_mean, weight_std = calc_mean_dev(base_params) + + log_csv{ + epoch = epoch_i, + trial_mean = trial_mean, + trial_std = trial_std, + delta_mean = delta_mean, + delta_std = delta_std, + step_std = step_dev, + adam_std = momstep_dev, + weight_mean = weight_mean, + weight_std = weight_std, + } + if cfg.enable_network then network:distribute(base_params) network:save()