From 3030e83d00873e7c8e9def4930750092f9e479f8 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 14 May 2018 01:34:08 +0200 Subject: [PATCH] refactor learn_from_epoch --- main.lua | 208 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 119 insertions(+), 89 deletions(-) diff --git a/main.lua b/main.lua index 6744f75..a334199 100644 --- a/main.lua +++ b/main.lua @@ -48,6 +48,7 @@ local last_trial_state -- localize some stuff. +local assert = assert local print = print local ipairs = ipairs local pairs = pairs @@ -79,6 +80,7 @@ local arshift = bit.arshift local rol = bit.rol local ror = bit.ror +local emu = emu local gui = gui local util = require("util") @@ -177,8 +179,6 @@ local function prepare_epoch() for i = 1, cfg.epoch_trials do local noise = nn.zeros(#base_params) if cfg.graycode then - --local precision = 1 / cfg.deviation - --print(cfg.deviation, precision) for j = 1, #base_params do noise[j] = exp(-precision * nn.uniform()) end @@ -244,6 +244,107 @@ local function load_next_trial() network:distribute(W) end +local function collect_best_indices() + -- select one (the best) reward of each pos/neg pair. + local best_rewards = {} + if cfg.negate_trials then + for i = 1, cfg.epoch_trials do + local ind = (i - 1) * 2 + 1 + local pos = trial_rewards[ind + 0] + local neg = trial_rewards[ind + 1] + best_rewards[i] = max(pos, neg) + end + else + best_rewards = copy(trial_rewards) + end + + local indices = {} + for i = 1, #best_rewards do indices[i] = i end + sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end) + + for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end + return indices +end + +local function kinda_lipschitz(dir, pos, neg, mid) + local _, dev = calc_mean_dev(dir) + local c0 = neg - mid + local c1 = pos - mid + local l0 = abs(3 * c1 + c0) + local l1 = abs(c1 + 3 * c0) + return max(l0, l1) / (2 * dev) +end + +local function make_step_paired(rewards, current_cost) + local step = nn.zeros(#base_params) + local _, reward_dev = calc_mean_dev(rewards) + if reward_dev == 0 then reward_dev = 1 end + + for i = 1, cfg.epoch_trials do + local ind = (i - 1) * 2 + 1 + local pos = rewards[ind + 0] + local neg = rewards[ind + 1] + local reward = pos - neg + if reward ~= 0 then + local noise = trial_noise[i] + + if cfg.ars_lips then + local lips = kinda_lipschitz(noise, pos, neg, current_cost) + reward = reward / lips / cfg.deviation + else + reward = reward / reward_dev + end + + for j, v in ipairs(noise) do + step[j] = step[j] + reward * v / cfg.epoch_top_trials + end + end + end + return step +end + +local function make_step(rewards) + local step = nn.zeros(#base_params) + local _, reward_dev = calc_mean_dev(rewards) + if reward_dev == 0 then reward_dev = 1 end + + for i = 1, cfg.epoch_trials do + local reward = rewards[i] / reward_dev + if reward ~= 0 then + local noise = trial_noise[i] + + for j, v in ipairs(noise) do + step[j] = step[j] + reward * v / cfg.epoch_top_trials + end + end + end + return step +end + +local function amsgrad(step) -- in-place! + if mom1 == nil then mom1 = nn.zeros(#step) end + if mom2 == nil then mom2 = nn.zeros(#step) end + if mom2max == nil then mom2max = nn.zeros(#step) end + + local b1_t = pow(cfg.adam_b1, epoch_i) + local b2_t = pow(cfg.adam_b2, epoch_i) + + -- NOTE: with LuaJIT, splitting this loop would + -- almost certainly be faster. + for i, v in ipairs(step) do + mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v + mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v + mom2max[i] = max(mom2[i], mom2max[i]) + if cfg.adam_debias then + local num = (mom1[i] / (1 - b1_t)) + local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps + step[i] = num / den + else + step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps) + end + end +end + local function learn_from_epoch() print() --print('rewards:', trial_rewards) @@ -259,37 +360,17 @@ local function learn_from_epoch() print(("test trial: %d out of %d"):format(nth_place, #trial_rewards)) end - local step = nn.zeros(#base_params) - - local delta_rewards -- only used for logging. - local best_rewards + local delta_rewards = {} -- only used for logging. if cfg.negate_trials then - delta_rewards = {} for i = 1, cfg.epoch_trials do local ind = (i - 1) * 2 + 1 local pos = trial_rewards[ind + 0] local neg = trial_rewards[ind + 1] delta_rewards[i] = abs(pos - neg) end - - -- select one (the best) reward of each pos/neg pair. - best_rewards = {} - for i = 1, cfg.epoch_trials do - local ind = (i - 1) * 2 + 1 - local pos = trial_rewards[ind + 0] - local neg = trial_rewards[ind + 1] - best_rewards[i] = max(pos, neg) - end - else - best_rewards = copy(trial_rewards) end - local indices = {} - for i = 1, #best_rewards do indices[i] = i end - sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end) - - --print("indices:", indices) - for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end + local indices = collect_best_indices() print("best trials:", indices) local top_rewards = {} @@ -301,77 +382,29 @@ local function learn_from_epoch() end --print("top:", top_rewards) - local top_delta_rewards = {} -- only used for printing. - for i, ind in ipairs(indices) do - local sind = (ind - 1) * 2 + 1 - top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) - end - print("best deltas:", top_delta_rewards) - - if not cfg.ars_lips then - local _, reward_dev = calc_mean_dev(top_rewards) - --print("mean, dev:", _, reward_dev) - if reward_dev == 0 then reward_dev = 1 end - for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end - end - - for i = 1, cfg.epoch_trials do - local ind = (i - 1) * 2 + 1 - local pos = top_rewards[ind + 0] - local neg = top_rewards[ind + 1] - local reward = pos - neg - if reward ~= 0 then - local noise = trial_noise[i] - - if cfg.ars_lips then - local _, dev = calc_mean_dev(noise) - local c0 = neg - current_cost - local c1 = pos - current_cost - local l0 = abs(3 * c1 + c0) - local l1 = abs(c1 + 3 * c0) - local lips = max(l0, l1) / (2 * dev) - --reward = pos / lips - neg / lips - local old_reward = reward - reward = reward / lips - reward = reward / cfg.deviation -- FIXME: hack? - --print(("trial %i reward: %.0f -> %.5f"):format(i, old_reward, reward)) - end - - for j, v in ipairs(noise) do - step[j] = step[j] + reward * v / cfg.epoch_top_trials - end + if cfg.negate_trials then + local top_delta_rewards = {} -- only used for printing. + for i, ind in ipairs(indices) do + local sind = (ind - 1) * 2 + 1 + top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) end + print("best deltas:", top_delta_rewards) + end + + local step + if cfg.negate_trials then + step = make_step_paired(top_rewards, current_cost) + else + step = make_step(top_rewards) end local step_mean, step_dev = calc_mean_dev(step) print("step mean:", step_mean) print("step stddev:", step_dev) - --print("full step stddev:", cfg.learning_rate * step_dev) local momstep_mean, momstep_dev = 0, 0 if cfg.adamant then - if mom1 == nil then mom1 = nn.zeros(#step) end - if mom2 == nil then mom2 = nn.zeros(#step) end - if mom2max == nil then mom2max = nn.zeros(#step) end - - local b1_t = pow(cfg.adam_b1, epoch_i) - local b2_t = pow(cfg.adam_b2, epoch_i) - - -- NOTE: with LuaJIT, splitting this loop would - -- almost certainly be faster. - for i, v in ipairs(step) do - mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v - mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v - mom2max[i] = max(mom2[i], mom2max[i]) - if cfg.adam_debias then - local num = (mom1[i] / (1 - b1_t)) - local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps - step[i] = num / den - else - step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps) - end - end - + amsgrad(step) momstep_mean, momstep_dev = calc_mean_dev(step) print("amsgrad mean:", momstep_mean) print("amsgrad stddev:", momstep_dev) @@ -480,11 +513,8 @@ local function do_reset() game.W(0x756, 1) end - --max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time) - --max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time) max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time) max_time = ceil(max_time) - --max_time = cfg.cap_time if once then savestate.load(startsave)