-- file preamble: dependencies, mutable script state, and localized names.
-- every declaration sits on its own line so that the trailing `--` comments
-- only comment out what they are meant to.

local globalize = require("strict")

-- configuration.

local cfg = require("config")
local gcfg = require("gameconfig")

-- state.

local epoch_i = 0
local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true
local trial_noise = {}
local trial_rewards = {}
local trials_remaining = 0
local mom1 -- first moments in AMSgrad.
local mom2 -- second moments in AMSgrad.
local mom2max -- running element-wise maximum of mom2.

local trial_frames = 0
local total_frames = 0
local lagless_count = 0

local force_start = false
local force_start_old = false

-- `savestate` is a global that is not defined in this file;
-- presumably provided by the host emulator.
local startsave = savestate.create(1)

local poketime = false
local max_time

local jp

local screen_scroll_delta
local reward
--local all_rewards = {}

local powerup_old
local status_old
local coins_old
local score_old

local once = false
local reset = true

local state_old = ''
local last_trial_state

-- localize some stuff.

local assert = assert
local print = print
local ipairs = ipairs
local pairs = pairs
local select = select
local open = io.open
local abs = math.abs
local floor = math.floor
local ceil = math.ceil
local min = math.min
local max = math.max
local exp = math.exp
local pow = math.pow
local log = math.log
local sqrt = math.sqrt
local random = math.random
local randomseed = math.randomseed
local insert = table.insert
local remove = table.remove
local unpack = table.unpack or unpack
local sort = table.sort

local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift
local rshift = bit.rshift
local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror

local emu = emu
local gui = gui

local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
local lerp = util.lerp
local softchoice = util.softchoice
local unperturbed_rank = util.unperturbed_rank

local game = require("smb")
game.overlay = cfg.enable_overlay

-- utilities.

-- maps each log key to its (1-based) column in the CSV log.
local log_map = {
    epoch = 1,
    trial_mean = 2,
    trial_std = 3,
    delta_mean = 4,
    delta_std = 5,
    step_std = 6,
    adam_std = 7,
    weight_mean = 8,
    weight_std = 9,
    test_trial = 10,
}

-- appends one row to the CSV file named by cfg.log_fn.
-- does nothing when cfg.log_fn is nil.
-- t: table that must contain exactly the keys of log_map.
-- raises an error on an unexpected key, a missing key,
-- or when the log file cannot be opened.
local function log_csv(t)
    if cfg.log_fn == nil then return end

    -- validate and order the row *before* touching the file,
    -- so that a bad key cannot leave an open handle behind.
    local values = {}
    for k, v in pairs(t) do
        local i = log_map[k]
        if i == nil then error("Unexpected log key "..tostring(k)) end
        values[i] = v
    end

    local count = 0
    for k, i in pairs(log_map) do
        if values[i] == nil then error("Missing log key "..tostring(k)) end
        count = count + 1
    end

    -- every column is known to be present here, so values is a proper sequence.
    local fields = {}
    for i = 1, count do fields[i] = tostring(values[i]) end

    local f = open(cfg.log_fn, 'a')
    if f == nil then error("Failed to open log file "..cfg.log_fn) end
    f:write(table.concat(fields, ","), '\n')
    f:close()
end

-- network parameters.

package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")

local network
local nn_x, nn_tx, nn_ty, nn_y, nn_z

-- builds the model and stores its nodes in the nn_* variables above.
-- input_size: length of the first (non-tile) input vector.
-- the second input has gcfg.tile_count entries and is fed through an Embed layer.
-- both are merged, passed through a 128-unit Dense layer with an activation
-- (Relu when cfg.deterministic, otherwise Gelu), optionally a LayerNorm,
-- and finally a Dense + Softmax with one output per entry of gcfg.jp_lut.
-- returns the nn.Model.
local function make_network(input_size)
    nn_x = nn.Input({input_size})
    nn_tx = nn.Input({gcfg.tile_count})
    nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
    nn_y = nn.Merge()
    nn_x:feed(nn_y)
    nn_ty:feed(nn_y)

    nn_y = nn_y:feed(nn.Dense(128))
    if cfg.deterministic then
        nn_y = nn_y:feed(nn.Relu())
    else
        nn_y = nn_y:feed(nn.Gelu())
    end
    if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end

    nn_z = nn_y
    nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
    nn_z = nn_z:feed(nn.Softmax())

    return nn.Model({nn_x, nn_tx}, {nn_z})
end

-- learning and evaluation.
-- snapshots the current network parameters and, unless in playback mode,
-- draws a fresh noise vector for every trial of the upcoming epoch.
-- also clears the recorded rewards and rewinds trial_i to -1.
local function prepare_epoch()
    base_params = network:collect()
    if cfg.playback_mode then return end

    print('preparing epoch '..tostring(epoch_i)..'.')

    empty(trial_noise)
    empty(trial_rewards)

    -- only used by the graycode noise below.
    local precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
    if cfg.graycode then
        print(("chosen precision: %.2f"):format(precision))
    end

    for i = 1, cfg.epoch_trials do
        local noise = nn.zeros(#base_params)
        if cfg.graycode then
            -- magnitudes first, then random signs.
            for j = 1, #base_params do
                noise[j] = exp(-precision * nn.uniform())
            end
            for j = 1, #base_params do
                noise[j] = nn.uniform() < 0.5 and noise[j] or -noise[j]
            end
        else
            for j = 1, #base_params do
                noise[j] = cfg.deviation * nn.normal()
            end
        end
        trial_noise[i] = noise
    end

    trial_i = -1
end

-- loads the next half of a pos/neg pair into the network.
-- each noise vector is used twice: first added to, then subtracted from,
-- the base parameters. trial_neg flips on every perturbed load.
-- trial 0 (no noise) is only run when cfg.unperturbed_trial is set.
local function load_next_pair()
    trial_i = trial_i + 1
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_i = 1
        trial_neg = true
    end

    local W = copy(base_params)

    if trial_i > 0 then
        if trial_neg then
            local noise = trial_noise[trial_i]
            for i, v in ipairs(base_params) do
                W[i] = v + noise[i]
            end
        else
            -- second half of the pair: undo the increment above
            -- so that the same noise vector is reused.
            trial_i = trial_i - 1
            local noise = trial_noise[trial_i]
            for i, v in ipairs(base_params) do
                W[i] = v - noise[i]
            end
        end

        trial_neg = not trial_neg
    end

    network:distribute(W)
end

-- loads the parameters for the next trial into the network.
-- defers to load_next_pair when cfg.negate_trials is set.
local function load_next_trial()
    if cfg.negate_trials then return load_next_pair() end
    trial_i = trial_i + 1
    local W = copy(base_params)
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_i = 1
    end
    if trial_i > 0 then
        print('loading trial', trial_i)
        local noise = trial_noise[trial_i]
        for i, v in ipairs(base_params) do
            W[i] = v + noise[i]
        end
    else
        print("test trial")
    end
    network:distribute(W)
end

-- returns the indices of the cfg.epoch_top_trials best trials,
-- ordered from highest to lowest reward.
local function collect_best_indices()
    -- select one (the best) reward of each pos/neg pair.
    local best_rewards = {}
    if cfg.negate_trials then
        for i = 1, cfg.epoch_trials do
            local ind = (i - 1) * 2 + 1
            local pos = trial_rewards[ind + 0]
            local neg = trial_rewards[ind + 1]
            best_rewards[i] = max(pos, neg)
        end
    else
        best_rewards = copy(trial_rewards)
    end

    local indices = argsort(best_rewards, function(a, b) return a > b end)

    -- drop everything past the top trials.
    for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
    return indices
end

-- scaling factor used by make_step_paired when cfg.ars_lips is set.
-- dir: noise vector. pos, neg: rewards of the pair. mid: unperturbed reward.
-- returns max(|3*c1 + c0|, |c1 + 3*c0|) / (2 * dev), where c0 = neg - mid,
-- c1 = pos - mid and dev is the deviation of dir from calc_mean_dev.
local function kinda_lipschitz(dir, pos, neg, mid)
    local _, dev = calc_mean_dev(dir)
    local c0 = neg - mid
    local c1 = pos - mid
    local l0 = abs(3 * c1 + c0)
    local l1 = abs(c1 + 3 * c0)
    return max(l0, l1) / (2 * dev)
end

-- builds the update step from pos/neg reward pairs.
-- rewards: flat list laid out as pos1, neg1, pos2, neg2, ...
-- current_cost: reward of the unperturbed trial; only read when cfg.ars_lips
--   is set. NOTE(review): it may be nil if no unperturbed trial was run.
-- returns a vector with as many entries as base_params.
local function make_step_paired(rewards, current_cost)
    local step = nn.zeros(#base_params)
    local _, reward_dev = calc_mean_dev(rewards)
    if reward_dev == 0 then reward_dev = 1 end -- avoid dividing by zero.

    for i = 1, cfg.epoch_trials do
        local ind = (i - 1) * 2 + 1
        local pos = rewards[ind + 0]
        local neg = rewards[ind + 1]
        local reward = pos - neg
        if reward ~= 0 then
            local noise = trial_noise[i]

            if cfg.ars_lips then
                local lips = kinda_lipschitz(noise, pos, neg, current_cost)
                reward = reward / lips / cfg.deviation
            else
                reward = reward / reward_dev
            end

            for j, v in ipairs(noise) do
                step[j] = step[j] + reward * v / cfg.epoch_top_trials
            end
        end
    end
    return step
end

-- builds the update step from unpaired rewards (one reward per noise vector).
-- rewards: list indexed like trial_noise.
-- returns a vector with as many entries as base_params.
local function make_step(rewards)
    local step = nn.zeros(#base_params)
    local _, reward_dev = calc_mean_dev(rewards)
    if reward_dev == 0 then reward_dev = 1 end -- avoid dividing by zero.

    for i = 1, cfg.epoch_trials do
        local reward = rewards[i] / reward_dev
        if reward ~= 0 then
            local noise = trial_noise[i]
            for j, v in ipairs(noise) do
                step[j] = step[j] + reward * v / cfg.epoch_top_trials
            end
        end
    end
    return step
end

-- applies AMSgrad to step, overwriting its entries.
-- keeps running state in mom1, mom2 and mom2max across calls.
local function amsgrad(step) -- in-place!
    if mom1 == nil then mom1 = nn.zeros(#step) end
    if mom2 == nil then mom2 = nn.zeros(#step) end
    if mom2max == nil then mom2max = nn.zeros(#step) end

    -- only used for debiasing.
    local b1_t = pow(cfg.adam_b1, epoch_i)
    local b2_t = pow(cfg.adam_b2, epoch_i)

    -- NOTE: with LuaJIT, splitting this loop would
    --       almost certainly be faster.
    for i, v in ipairs(step) do
        mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
        mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
        mom2max[i] = max(mom2[i], mom2max[i])
        if cfg.adam_debias then
            local num = (mom1[i] / (1 - b1_t))
            local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
            step[i] = num / den
        else
            step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
        end
    end
end

-- turns the rewards collected during an epoch into a parameter update:
-- keeps only the top trials, builds a step (optionally run through amsgrad),
-- applies learning rate and weight decay to base_params, logs statistics,
-- and (when cfg.enable_network is set) distributes and saves the new weights.
local function learn_from_epoch()
    print()
    --print('rewards:', trial_rewards)

    --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end

    local current_cost = trial_rewards[0] -- may be nil!

    if cfg.unperturbed_trial then
        local nth_place = unperturbed_rank(trial_rewards, current_cost)

        -- a rank of 1 means our gradient is uninformative.
        print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
    end

    local delta_rewards = {} -- only used for logging.
    if cfg.negate_trials then
        for i = 1, cfg.epoch_trials do
            local ind = (i - 1) * 2 + 1
            local pos = trial_rewards[ind + 0]
            local neg = trial_rewards[ind + 1]
            delta_rewards[i] = abs(pos - neg)
        end
    end

    local indices = collect_best_indices()
    print("best trials:", indices)

    -- zero out every reward that isn't among the top trials.
    local top_rewards = {}
    for i = 1, #trial_rewards do top_rewards[i] = 0 end
    if cfg.negate_trials then
        -- rewards are stored as pos/neg pairs; keep both halves.
        for _, ind in ipairs(indices) do
            local sind = (ind - 1) * 2 + 1
            top_rewards[sind + 0] = trial_rewards[sind + 0]
            top_rewards[sind + 1] = trial_rewards[sind + 1]
        end
    else
        -- one reward per trial: indices map straight onto trial_rewards.
        -- (pair indexing here would read the wrong slots and leave holes.)
        for _, ind in ipairs(indices) do
            top_rewards[ind] = trial_rewards[ind]
        end
    end
    --print("top:", top_rewards)

    if cfg.negate_trials then
        local top_delta_rewards = {} -- only used for printing.
        for i, ind in ipairs(indices) do
            local sind = (ind - 1) * 2 + 1
            top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
        end
        print("best deltas:", top_delta_rewards)
    end

    local step
    if cfg.negate_trials then
        step = make_step_paired(top_rewards, current_cost)
    else
        step = make_step(top_rewards)
    end

    local step_mean, step_dev = calc_mean_dev(step)
    print("step mean:", step_mean)
    print("step stddev:", step_dev)

    local momstep_mean, momstep_dev = 0, 0
    if cfg.adamant then
        amsgrad(step)

        momstep_mean, momstep_dev = calc_mean_dev(step)
        print("amsgrad mean:", momstep_mean)
        print("amsgrad stddev:", momstep_dev)
    end

    for i, v in ipairs(base_params) do
        base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
    end

    local trial_mean, trial_std = calc_mean_dev(trial_rewards)
    local delta_mean, delta_std = calc_mean_dev(delta_rewards)
    local weight_mean, weight_std = calc_mean_dev(base_params)

    log_csv{
        epoch       = epoch_i,
        trial_mean  = trial_mean,
        trial_std   = trial_std,
        delta_mean  = delta_mean,
        delta_std   = delta_std,
        step_std    = step_dev,
        adam_std    = momstep_dev,
        weight_mean = weight_mean,
        weight_std  = weight_std,
        test_trial  = current_cost or 0,
    }

    if cfg.enable_network then
        network:distribute(base_params)
        network:save(cfg.params_fn)
    else
        print("note: not updating weights in playable mode.")
    end

    print()
end

-- writes a joypad state to player 1 in which only `button` may be held,
-- and only on odd frame counts, so calling this every frame mashes it.
-- button: one of up, down, left, right, A, B, select, start.
local function joypad_mash(button)
    local jp_mash = {
        up      = false,
        down    = false,
        left    = false,
        right   = false,
        A       = false,
        B       = false,
        select  = false,
        start   = false,
    }
    -- assert only takes a value and a message.
    assert(jp_mash[button] == false, "invalid button: "..tostring(button))
    jp_mash[button] = emu.framecount() % 2 == 1
    joypad.write(1, jp_mash)
end

-- finishes the current trial and sets up the next one:
-- reports and records the reward, runs learn_from_epoch/prepare_epoch at
-- epoch boundaries, restores (or, the first time, creates) the starting
-- savestate, resets per-trial state, and loads the next trial's weights.
local function do_reset()
    local state = game.get_state()
    -- be a little more descriptive.
    if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end

    if trial_i >= 0 then
        if trial_i == 0 then
            print('test trial reward:', reward, "("..state..")")
        elseif cfg.negate_trials then
            --local dir = trial_neg and "negative" or "positive"
            --print('trial', trial_i, dir, 'reward:', reward, "("..state..")")

            -- trial_neg has already been flipped by load_next_pair,
            -- so true here means both halves of the pair are done.
            if trial_neg then
                local pos = trial_rewards[#trial_rewards]
                local neg = reward

                local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
                print(fmt:format(trial_i, pos, neg, last_trial_state, state))
            end
            last_trial_state = state
        else
            print('trial', trial_i, 'reward:', reward, "("..state..")")
        end
    else
        print("reward:", reward, "("..state..")")
    end

    if trial_i >= 0 then
        if trial_i == 0 or not cfg.negate_trials then
            trial_rewards[trial_i] = reward
        else
            -- paired trials are appended in pos, neg order.
            trial_rewards[#trial_rewards + 1] = reward
        end
    end

    if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
        if epoch_i > 0 then learn_from_epoch() end
        if not cfg.playback_mode then epoch_i = epoch_i + 1 end
        prepare_epoch()
    end

    if game.get_state() == 'loading' then game.advance() end -- kind of a hack.
    reward = 0
    powerup_old = game.R(0x754)
    status_old = game.R(0x756)
    coins_old = game.R(0x7ED) * 10 + game.R(0x7EE)
    score_old = game.get_score()

    -- set number of lives. (mario gets n+1 chances)
    game.W(0x75A, cfg.starting_lives)

    if cfg.start_big then -- make mario "super".
        game.W(0x754, 0)
        game.W(0x756, 1)
    end

    -- the time limit grows with the epoch count, up to cfg.cap_time.
    -- epoch_i stays 0 in playback mode; the max() keeps sqrt away from
    -- a negative argument there. it is a no-op once epoch_i >= 1.
    max_time = min(6 * sqrt(480 / cfg.epoch_trials * max(epoch_i - 1, 0)) + 60, cfg.cap_time)
    max_time = ceil(max_time)

    if once then
        savestate.load(startsave)
    else
        savestate.save(startsave)
    end

    once = true
    jp = nil
    screen_scroll_delta = 0
    trial_frames = 0
    emu.frameadvance() -- prevents emulator from quirking up.

    load_next_trial()

    reset = false
end

-- builds the network, boots the emulator, mashes start until frame 195,
-- and then tries to load saved parameters from cfg.params_fn.
-- a failed load is printed and otherwise ignored.
local function init()
    network = make_network(gcfg.input_size)
    network:reset()
    network:print()
    print("parameters:", network.n_param)

    if cfg.init_zeros then
        local W = network:collect()
        for i, w in ipairs(W) do W[i] = 0 end
        network:distribute(W)
    end

    emu.poweron()
    emu.unpause()
    emu.speedmode("turbo")

    while emu.framecount() < 195 do -- FIXME: don't hardcode this.
        joypad_mash('start')
        emu.frameadvance()
    end
    --print(emu.framecount())

    local res, err = pcall(network.load, network, cfg.params_fn)
    if res == false then print(err) end
end

-- requests a reset on the next iteration of the main loop.
-- ignored in playback mode.
local function prepare_reset()
    if cfg.playback_mode then return end
    reset = true
end

-- draws a frame counter at (x, y) as three comma-separated groups
-- of three digits: millions, thousands, units.
local function draw_framecount(x, y, frames)
    local tf0 = frames % 1000
    local tf1 = (frames % 1000000 - tf0) / 1000
    -- tf1 counts thousands, so scale it back up before subtracting;
    -- otherwise tf2 ends up with a fractional part.
    local tf2 = (frames - tf0 - tf1 * 1000) / 1000000
    gui.text(x, y, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
end

-- per-frame logic: adjusts the timer, accumulates reward, gathers the
-- network inputs and picks the next joypad state (stored in jp).
-- dummy: when true, only the timer, scroll accumulation and overlay are
--   handled; the inputs are not gathered and jp is left untouched.
local function doit(dummy)
    local ingame_paused = game.get_state() == "paused"

    -- every few frames mario stands still, forcibly decrease the timer.
    -- this includes having the game paused.
    -- TODO: more robust. doesn't detect moonwalking against a wall.
    -- well, that shouldn't happen anymore now that i've disabled left+right.
    local timer = game.get_timer()
    -- NOTE: `and` binds tighter than `or`, so being paused always counts.
    if ingame_paused or random() > 1 - cfg.timer_loser and game.R(0x1D) == 0 and game.R(0x57) == 0 then
        timer = timer - 1
    end
    if not cfg.playback_mode then
        timer = clamp(timer, 0, max_time)
        if cfg.enable_network then
            game.set_timer(timer)
        end
    end

    draw_framecount(12, 215, total_frames)

    screen_scroll_delta = screen_scroll_delta + game.R(0x775)

    if dummy == true then
        -- don't invoke AI this frame. (keep holding the old inputs)
        gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
        return
    end

    empty(game.sprite_input)
    empty(game.tile_input)
    empty(game.extra_input)

    -- TODO: check if mario is in a playable state.
    local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = game.R(0x754)
    local status = game.R(0x756)
    game.mark_sprite(x + 8, y + 24, -powerup - 1)

    local vx, vy = game.S(0x57), game.S(0x9F)
    insert(game.extra_input, vx * 16)
    insert(game.extra_input, vy * 16)

    if cfg.time_inputs then
        -- feed bits 2 through 5 of the trial's frame counter as +/-181.
        for i=2,5 do
            local v = band(trial_frames, lshift(1, i)) == 0 and -181 or 181
            insert(game.extra_input, v)
        end
    end

    game.handle_enemies()
    game.handle_fireballs()
    --game.handle_blocks() -- blocks being hit. not interactable; we don't care!
    game.handle_hammers()
    game.handle_misc()
    game.handle_tiles()

    -- NOTE: coins_delta, powerup_delta and status_delta are currently unused;
    -- only the commented-out reward formula below refers to one of them.
    local coins = game.R(0x7ED) * 10 + game.R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end

    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)

    local flagpole_bonus = game.R(0xE) == 4 and cfg.frameskip or 0
    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
    local score_delta = game.get_score() - score_old
    if score_delta < 0 then score_delta = 0 end
    local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
    screen_scroll_delta = 0

    if cfg.decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end

    if not ingame_paused then reward = reward + reward_delta end

    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')

    if game.get_state() == 'dead' and state_old ~= 'dead' then
        --print("dead. lives remaining:", game.R(0x75A, 0))
        if game.R(0x75A, 0) == 0 then prepare_reset() end
    end
    if game.get_state() == 'lose' then
        -- this shouldn't happen if we catch the deaths as above.
        print("ran out of lives.")
        if not cfg.playback_mode then prepare_reset() end
    end

    -- lose a point for every frame paused.
    --if ingame_paused then reward = reward - 1 end
    -- instead: pausing costs a flat penalty and ends the trial.
    if ingame_paused then reward = reward - 402; prepare_reset() end

    -- if we've run out of time while the game is paused...
    -- that's cheating! unpause.
    force_start = ingame_paused and timer == 0

    local X = {}
    for i, v in ipairs(game.sprite_input) do insert(X, v / 256) end
    for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
    nn.reshape(X, 1, gcfg.input_size)
    nn.reshape(game.tile_input, 1, gcfg.tile_count)

    trial_frames = trial_frames + cfg.frameskip

    -- NOTE(review): this parses as
    -- (cfg.enable_network and state == 'playing') or ingame_paused,
    -- so the network also runs while paused with cfg.enable_network off.
    -- confirm whether that is intended before changing it.
    if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
        total_frames = total_frames + cfg.frameskip

        local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})

        local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
        if cfg.det_epsilon and random() < eps then
            -- pick a uniformly random entry of the joypad lookup table.
            local i = floor(random() * #gcfg.jp_lut) + 1
            jp = copy(gcfg.jp_lut[i], jp)
        else
            local choose = cfg.deterministic and argmax or softchoice
            local ind = choose(unpack(outputs[nn_z]))
            jp = copy(gcfg.jp_lut[ind], jp)
        end

        if force_start then
            -- release everything; start follows last frame's force_start.
            jp = {
                up = false,
                down = false,
                left = false,
                right = false,
                A = false,
                B = false,
                start = force_start_old,
                select = false,
            }
        end
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    force_start_old = force_start
    state_old = game.get_state()
    score_old = game.get_score()
end

init()

-- main loop: one iteration per (non-lag) frame. never returns.
while true do
    gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

    while gcfg.bad_states[game.get_state()] do
        -- mash the start button until we have control.
        joypad_mash('start')
        prepare_reset()

        game.advance()
        gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

        while game.get_state() == "loading" do game.advance() end -- kind of a hack.
        state_old = game.get_state()
    end

    if reset then
        do_reset()
        lagless_count = 0
    end

    if not cfg.enable_network then
        -- infinite time cheat. super handy for testing.
        if game.R(0xE) == 8 then
            game.set_timer(667)
            poketime = true
        elseif poketime then
            poketime = false
            game.set_timer(1)
        end

        -- infinite lives.
        game.W(0x75A, 1)
    end

    -- only run the full logic every cfg.frameskip frames,
    -- or whenever there are no inputs to hold yet.
    local doot = jp == nil or lagless_count % cfg.frameskip == 0
    doit(not doot)

    -- jp might still be nil if we're not ingame or we're not playing.
    if jp ~= nil then joypad.write(1, jp) end

    game.advance() -- remember, this skips lag frames.
    lagless_count = lagless_count + 1
end