diff --git a/README.txt b/README.txt index f0dd665..4472474 100644 --- a/README.txt +++ b/README.txt @@ -3,6 +3,7 @@ please be mindful when sharing it. however, feel free to copy any snippets of code you find useful. TODOs: (that i can remember right now) +- normalize `for i=a,b` code style - normalize and/or embed sprite type inputs - settle on a network architecture - compute how many input neurons the network needs instead of hardcoding diff --git a/ars.lua b/ars.lua new file mode 100644 index 0000000..cc9431a --- /dev/null +++ b/ars.lua @@ -0,0 +1,209 @@ +-- Augmented Random Search +-- https://arxiv.org/abs/1803.07055 +-- with some tweaks (lips) by myself. + +local abs = math.abs +local floor = math.floor +local ipairs = ipairs +local max = math.max +local print = print + +local Base = require "Base" + +local nn = require "nn" +local normal = nn.normal +local prod = nn.prod +local zeros = nn.zeros + +local util = require "util" +local argsort = util.argsort +local calc_mean_dev = util.calc_mean_dev + +local Ars = Base:extend() + +local function collect_best_indices(scored, top, antithetic) + -- select one (the best) reward of each pos/neg pair. + local best_rewards + if antithetic then + best_rewards = {} + for i = 1, #scored, 2 do + local ind = floor(i / 2) + 1 + local pos = scored[i + 0] + local neg = scored[i + 1] + best_rewards[ind] = max(pos, neg) + end + else + best_rewards = scored + end + + local indices = argsort(best_rewards, function(a, b) return a > b end) + + for i = top + 1, #best_rewards do indices[i] = nil end + return indices +end + +local function kinda_lipschitz(dir, pos, neg, mid) + local _, dev = calc_mean_dev(dir) + local c0 = neg - mid + local c1 = pos - mid + local l0 = abs(3 * c1 + c0) + local l1 = abs(c1 + 3 * c0) + return max(l0, l1) / (2 * dev) +end + +local function amsgrad(step) -- in-place! -- TODO: fix this. + if mom1 == nil then mom1 = nn.zeros(#step) end + if mom2 == nil then mom2 = nn.zeros(#step) end + if mom2max == nil then mom2max = nn.zeros(#step) end + + local b1_t = pow(cfg.adam_b1, epoch_i) + local b2_t = pow(cfg.adam_b2, epoch_i) + + -- NOTE: with LuaJIT, splitting this loop would + -- almost certainly be faster. + for i, v in ipairs(step) do + mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v + mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v + mom2max[i] = max(mom2[i], mom2max[i]) + if cfg.adam_debias then + local num = (mom1[i] / (1 - b1_t)) + local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps + step[i] = num / den + else + step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps) + end + end +end + +function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic) + self.dims = dims + self.popsize = popsize or 4 + (3 * floor(log(dims))) + self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.sigma = sigma or 1 + self.antithetic = antithetic and true or false + + self.poptop = poptop or popsize + assert(self.poptop <= popsize) + if self.antithetic then self.popsize = self.popsize * 2 end + + self._params = nn.zeros(self.dims) +end + +function Ars:params(new_params) + if new_params ~= nil then + assert(#self._params == #new_params, "new parameters have the wrong size") + for i, v in ipairs(new_params) do self._params[i] = v end + end + return self._params +end + +function Ars:ask(graycode) + local asked = {} + local noise = {} + + for i = 1, self.popsize do + local asking = zeros(self.dims) + local noisy = zeros(self.dims) + asked[i] = asking + + if self.antithetic and i % 2 == 0 then + for j, v in ipairs(self._params) do + asking[i] = v - noisy[j] + end + + else + if graycode ~= nil then + for j = 1, self.dims do + noisy[j] = exp(-precision * nn.uniform()) + end + for j = 1, self.dims do + noisy[j] = nn.uniform() < 0.5 and noisy[j] or -noisy[j] + end + else + for j = 1, self.dims do + noisy[j] = self.sigma * nn.normal() + end + end + + for j, v in ipairs(self._params) do + asking[j] = v + noisy[j] + end + end + + noise[i] = noisy + end + + self.noise = noise + return asked, noise +end + +function Ars:tell(scored, unperturbed_score) + local indices = collect_best_indices(scored, self.poptop, self.antithetic) + --print("best trials:", indices) + + local top_rewards = {} + for i = 1, #scored do top_rewards[i] = 0 end + for _, ind in ipairs(indices) do + local sind = (ind - 1) * 2 + 1 + top_rewards[sind + 0] = scored[sind + 0] + top_rewards[sind + 1] = scored[sind + 1] + end + --print("top:", top_rewards) + + if self.antithetic then + local top_delta_rewards = {} -- only used for printing. + for i, ind in ipairs(indices) do + local sind = (ind - 1) * 2 + 1 + top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) + end + --print("best deltas:", top_delta_rewards) + end + + local step = nn.zeros(self.dims) + local _, reward_dev = calc_mean_dev(top_rewards) + if reward_dev == 0 then reward_dev = 1 end + + if self.antithetic then + for i = 1, floor(self.popsize / 2) do + local ind = (i - 1) * 2 + 1 + local pos = top_rewards[ind + 0] + local neg = top_rewards[ind + 1] + local reward = pos - neg + if reward ~= 0 then + local noisy = self.noise[i] + + if unperturbed_score ~= nil then + local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score) + reward = reward / lips / self.sigma + else + reward = reward / reward_dev + end + + for j, v in ipairs(noisy) do + step[j] = step[j] + reward * v / self.poptop + end + end + end + else + for i = 1, self.popsize do + local reward = top_rewards[i] / reward_dev + if reward ~= 0 then + local noisy = self.noise[i] + + for j, v in ipairs(noisy) do + step[j] = step[j] + reward * v / self.poptop + end + end + end + end + + for i, v in ipairs(self._params) do + self._params[i] = v + self.learning_rate * step[i] + end + + self.asked = nil +end + +return { + Ars = Ars, +} diff --git a/config.lua b/config.lua index 1799023..a3b00ba 100644 --- a/config.lua +++ b/config.lua @@ -28,10 +28,11 @@ local common_cfg = { graycode = false, unperturbed_trial = true, -- do a trial without any noise. negate_trials = true, -- try pairs of normal and negated noise directions. - -- ^ note that this now doubles the effective trials. + -- AKA antithetic sampling. note that this doubles the number of trials. time_inputs = true, -- binary inputs of global frame count normalize_inputs = false, + es = 'ars', ars_lips = false, adamant = false, -- run steps through AMSgrad. adam_b1 = math.pow(10, -1 / 1), -- fewer trials, more momentum! @@ -90,4 +91,7 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial, assert(not cfg.ars_lips or cfg.negate_trials, "cfg.negate_trials must be true to use cfg.ars_lips") +assert(not cfg.adamant, + "cfg.adamant not yet re-implemented") + return cfg diff --git a/main.lua b/main.lua index 82b4b64..e51e169 100644 --- a/main.lua +++ b/main.lua @@ -11,12 +11,13 @@ local epoch_i = 0 local base_params local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled. local trial_neg = true -local trial_noise = {} +local trial_params --= {} local trial_rewards = {} local trials_remaining = 0 local mom1 -- first moments in AMSgrad. local mom2 -- second moments in AMSgrad. local mom2max -- running element-wise maximum of mom2. +local es -- evolution strategy. local trial_frames = 0 local total_frames = 0 @@ -35,7 +36,6 @@ local jp local screen_scroll_delta local reward ---local all_rewards = {} local powerup_old local status_old @@ -172,191 +172,50 @@ end -- learning and evaluation. +local ars = require("ars") + local function prepare_epoch() + trial_neg = false + base_params = network:collect() if cfg.playback_mode then return end print('preparing epoch '..tostring(epoch_i)..'.') - empty(trial_noise) empty(trial_rewards) - local precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392 + local precision if cfg.graycode then + precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392 print(("chosen precision: %.2f"):format(precision)) end - for i = 1, cfg.epoch_trials do - local noise = nn.zeros(#base_params) - if cfg.graycode then - for j = 1, #base_params do - noise[j] = exp(-precision * nn.uniform()) - end - for j = 1, #base_params do - noise[j] = nn.uniform() < 0.5 and noise[j] or -noise[j] - end - else - for j = 1, #base_params do - noise[j] = cfg.deviation * nn.normal() - end - end - trial_noise[i] = noise + local dummy + if es == 'ars' then + trial_params, dummy = es:ask(precision) + else + trial_params, dummy = es:ask() end + trial_i = -1 end -local function load_next_pair() - trial_i = trial_i + 1 - if trial_i == 0 and not cfg.unperturbed_trial then - trial_i = 1 - trial_neg = true - end - - local W = copy(base_params) - - if trial_i > 0 then - if trial_neg then - local noise = trial_noise[trial_i] - for i, v in ipairs(base_params) do - W[i] = v + noise[i] - end - - else - trial_i = trial_i - 1 - local noise = trial_noise[trial_i] - for i, v in ipairs(base_params) do - W[i] = v - noise[i] - end - end - - trial_neg = not trial_neg - end - - network:distribute(W) -end - local function load_next_trial() - if cfg.negate_trials then return load_next_pair() end + if cfg.negate_trials then trial_neg = not trial_neg end trial_i = trial_i + 1 - local W = copy(base_params) if trial_i == 0 and not cfg.unperturbed_trial then trial_i = 1 end if trial_i > 0 then - print('loading trial', trial_i) - local noise = trial_noise[trial_i] - for i, v in ipairs(base_params) do - W[i] = v + noise[i] - end + --print('loading trial', trial_i) + network:distribute(trial_params[trial_i]) else - print("test trial") - end - network:distribute(W) -end - -local function collect_best_indices() - -- select one (the best) reward of each pos/neg pair. - local best_rewards = {} - if cfg.negate_trials then - for i = 1, cfg.epoch_trials do - local ind = (i - 1) * 2 + 1 - local pos = trial_rewards[ind + 0] - local neg = trial_rewards[ind + 1] - best_rewards[i] = max(pos, neg) - end - else - best_rewards = copy(trial_rewards) - end - - local indices = argsort(best_rewards, function(a, b) return a > b end) - - for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end - return indices -end - -local function kinda_lipschitz(dir, pos, neg, mid) - local _, dev = calc_mean_dev(dir) - local c0 = neg - mid - local c1 = pos - mid - local l0 = abs(3 * c1 + c0) - local l1 = abs(c1 + 3 * c0) - return max(l0, l1) / (2 * dev) -end - -local function make_step_paired(rewards, current_cost) - local step = nn.zeros(#base_params) - local _, reward_dev = calc_mean_dev(rewards) - if reward_dev == 0 then reward_dev = 1 end - - for i = 1, cfg.epoch_trials do - local ind = (i - 1) * 2 + 1 - local pos = rewards[ind + 0] - local neg = rewards[ind + 1] - local reward = pos - neg - if reward ~= 0 then - local noise = trial_noise[i] - - if cfg.ars_lips then - local lips = kinda_lipschitz(noise, pos, neg, current_cost) - reward = reward / lips / cfg.deviation - else - reward = reward / reward_dev - end - - for j, v in ipairs(noise) do - step[j] = step[j] + reward * v / cfg.epoch_top_trials - end - end - end - return step -end - -local function make_step(rewards) - local step = nn.zeros(#base_params) - local _, reward_dev = calc_mean_dev(rewards) - if reward_dev == 0 then reward_dev = 1 end - - for i = 1, cfg.epoch_trials do - local reward = rewards[i] / reward_dev - if reward ~= 0 then - local noise = trial_noise[i] - - for j, v in ipairs(noise) do - step[j] = step[j] + reward * v / cfg.epoch_top_trials - end - end - end - return step -end - -local function amsgrad(step) -- in-place! - if mom1 == nil then mom1 = nn.zeros(#step) end - if mom2 == nil then mom2 = nn.zeros(#step) end - if mom2max == nil then mom2max = nn.zeros(#step) end - - local b1_t = pow(cfg.adam_b1, epoch_i) - local b2_t = pow(cfg.adam_b2, epoch_i) - - -- NOTE: with LuaJIT, splitting this loop would - -- almost certainly be faster. - for i, v in ipairs(step) do - mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v - mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v - mom2max[i] = max(mom2[i], mom2max[i]) - if cfg.adam_debias then - local num = (mom1[i] / (1 - b1_t)) - local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps - step[i] = num / den - else - step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps) - end + --print("test trial") + network:distribute(base_params) end end local function learn_from_epoch() print() - --print('rewards:', trial_rewards) - - --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end local current_cost = trial_rewards[0] -- may be nil! @@ -369,58 +228,45 @@ local function learn_from_epoch() local delta_rewards = {} -- only used for logging. if cfg.negate_trials then - for i = 1, cfg.epoch_trials do - local ind = (i - 1) * 2 + 1 - local pos = trial_rewards[ind + 0] - local neg = trial_rewards[ind + 1] - delta_rewards[i] = abs(pos - neg) + for i = 1, #trial_rewards, 2 do + local ind = floor(i / 2) + 1 + local pos = trial_rewards[i + 0] + local neg = trial_rewards[i + 1] + delta_rewards[ind] = abs(pos - neg) end end - local indices = collect_best_indices() - print("best trials:", indices) - - local top_rewards = {} - for i = 1, #trial_rewards do top_rewards[i] = 0 end - for _, ind in ipairs(indices) do - local sind = (ind - 1) * 2 + 1 - top_rewards[sind + 0] = trial_rewards[sind + 0] - top_rewards[sind + 1] = trial_rewards[sind + 1] - end - --print("top:", top_rewards) - - if cfg.negate_trials then - local top_delta_rewards = {} -- only used for printing. - for i, ind in ipairs(indices) do - local sind = (ind - 1) * 2 + 1 - top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) - end - print("best deltas:", top_delta_rewards) - end - - local step - if cfg.negate_trials then - step = make_step_paired(top_rewards, current_cost) + if es == 'ars' then + es:tell(trial_rewards, current_cost) else - step = make_step(top_rewards) + es:tell(trial_rewards) end + local step_mean, step_dev = 0, 0 + --[[ TODO local step_mean, step_dev = calc_mean_dev(step) print("step mean:", step_mean) print("step stddev:", step_dev) + --]] local momstep_mean, momstep_dev = 0, 0 + --[[ TODO if cfg.adamant then amsgrad(step) momstep_mean, momstep_dev = calc_mean_dev(step) print("amsgrad mean:", momstep_mean) print("amsgrad stddev:", momstep_dev) end + --]] + + base_params = es:params() for i, v in ipairs(base_params) do - base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v + base_params[i] = v * (1 - cfg.weight_decay) end + es:params(base_params) + local trial_mean, trial_std = calc_mean_dev(trial_rewards) local delta_mean, delta_std = calc_mean_dev(delta_rewards) local weight_mean, weight_std = calc_mean_dev(base_params) @@ -465,6 +311,7 @@ local function joypad_mash(button) end local function loadlevel(world, level) + -- TODO: move to smb.lua. rename to load_level. if world == 0 then world = random(1, 8) end if level == 0 then level = random(1, 4) end emu.poweron() @@ -499,7 +346,8 @@ local function do_reset() local pos = trial_rewards[#trial_rewards] local neg = reward local fmt = "trial %i rewards: %+i, %+i (%s, %s)" - print(fmt:format(trial_i, pos, neg, last_trial_state, state)) + print(fmt:format(floor(trial_i / 2), + pos, neg, last_trial_state, state)) end last_trial_state = state else @@ -517,7 +365,7 @@ local function do_reset() end end - if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then + if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then if epoch_i > 0 then learn_from_epoch() end if not cfg.playback_mode then epoch_i = epoch_i + 1 end prepare_epoch() @@ -527,6 +375,11 @@ local function do_reset() end end + max_time = min(6 * sqrt(480 / #trial_params * (epoch_i - 1)) + 60, cfg.cap_time) + max_time = ceil(max_time) + + -- TODO: game.reset(cfg.starting_lives, cfg.start_big) + if game.get_state() == 'loading' then game.advance() end -- kind of a hack. reward = 0 powerup_old = game.R(0x754) @@ -543,8 +396,7 @@ local function do_reset() game.W(0x756, 1) end - max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time) - max_time = ceil(max_time) + -- end of game.reset() if state_saved then savestate.load(startsave) @@ -585,6 +437,13 @@ local function init() local res, err = pcall(network.load, network, cfg.params_fn) if res == false then print(err) end + + if cfg.es == 'ars' then + es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, + cfg.learning_rate, cfg.deviation, cfg.negate_trials) + else + error("Unknown evolution strategy specified: " + tostring(cfg.es)) + end end local function prepare_reset()