local globalize = require("strict")

-- configuration.
local cfg = require("config")
local gcfg = require("gameconfig")

-- state.

-- epoch / trial bookkeeping.
local epoch_i = 0
local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true
local trial_params --= {}
local trial_rewards = {}
local trials_remaining = 0

-- optimizer state: first moments, second moments, and the running
-- element-wise maximum of the second moments (AMSgrad).
local mom1, mom2, mom2max
local es -- evolution strategy.

-- frame counters.
local trial_frames, total_frames, lagless_count = 0, 0, 0

-- start-button forcing (current and previous update).
local force_start, force_start_old = false, false

-- emulator savestate slot restored at the beginning of each trial.
local startsave = savestate.create(1)
-- true when either the world or the level is chosen at random.
local any_random = cfg.starting_world == 0 or cfg.starting_level == 0
local poketime = false
local max_time
local jp
local screen_scroll_delta

-- reward tracking; the *_old values hold the previously sampled readings.
local reward
local powerup_old, status_old, coins_old, score_old

local state_saved = false
local reset = true
local state_old = ''
local last_trial_state

-- localize some stuff.
local assert = assert
local print = print
local ipairs = ipairs
local pairs = pairs
local select = select
local open = io.open
local abs = math.abs
local floor = math.floor
local ceil = math.ceil
local min = math.min
local max = math.max
local exp = math.exp
local pow = math.pow
local log = math.log
local sqrt = math.sqrt
local random = math.random
local randomseed = math.randomseed
local insert = table.insert
local remove = table.remove
local unpack = table.unpack or unpack
local sort = table.sort
local concat = table.concat
local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift
local rshift = bit.rshift
local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror
local emu = emu
local gui = gui

local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
local lerp = util.lerp
local softchoice = util.softchoice
local unperturbed_rank = util.unperturbed_rank

local game = require("smb")
game.overlay = cfg.enable_overlay

-- utilities.

-- maps each log key to its 1-based column in the CSV output.
local log_map = {
    epoch = 1,
    trial_mean = 2,
    trial_std = 3,
    delta_mean = 4,
    delta_std = 5,
    step_std = 6,
    adam_std = 7,
    weight_mean = 8,
    weight_std = 9,
    test_trial = 10,
}

--- Append one row to the CSV log file `cfg.log_fn` (no-op when it is nil).
-- Every key of `t` must appear in log_map, and every key of log_map must be
-- present in `t`; values are written (via tostring) in log_map column order.
-- The row is validated before the file is opened, so a malformed row raises
-- without leaving a file handle open.
-- @tparam table t map from log key to value
local function log_csv(t)
    if cfg.log_fn == nil then return end

    local values = {}
    for k, v in pairs(t) do
        local i = log_map[k]
        if i == nil then error("Unexpected log key "..tostring(k)) end
        values[i] = v
    end
    for k, i in pairs(log_map) do
        if values[i] == nil then error("Missing log key "..tostring(k)) end
    end

    local fields = {}
    for i, v in ipairs(values) do
        fields[i] = tostring(v)
    end

    local f, err = open(cfg.log_fn, 'a')
    if f == nil then
        error("Failed to open log file "..cfg.log_fn..": "..tostring(err))
    end
    f:write(concat(fields, ","), '\n')
    f:close()
end

-- network parameters.
package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")
local network
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z

--- Build the policy network.
-- Inputs: a flat feature vector of `input_size` values (nn_x) and
-- gcfg.tile_count tile ids (nn_tx). The tiles go through nn.Embed
-- (#game.valid_tiles entries, 2 dims), a Reshape to {13, 17 * 2}, a
-- DenseBroadcast(5) and a Relu; both branches are merged and mapped through
-- Dense(#gcfg.jp_lut) + Softmax, one output per joypad LUT entry.
-- @tparam number input_size length of the flat feature input
-- @return nn.Model with inputs {nn_x, nn_tx} and output {nn_z}
local function make_network(input_size)
    nn_x = nn.Input({input_size})
    nn_tx = nn.Input({gcfg.tile_count})
    nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
    nn_tz = nn_ty:feed(nn.Reshape{13, 17 * 2})
    nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
    nn_tz = nn_tz:feed(nn.Relu())
    -- note: due to a quirk in Merge, we don't need to flatten nn_tz.
    nn_y = nn.Merge()
    nn_x:feed(nn_y)
    nn_tz:feed(nn_y)
    --[[
    nn_y = nn_y:feed(nn.Dense(128))
    if cfg.deterministic then
        nn_y = nn_y:feed(nn.Relu())
    else
        nn_y = nn_y:feed(nn.Gelu())
    end
    if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
    --]]
    nn_z = nn_y
    nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
    nn_z = nn_z:feed(nn.Softmax())
    return nn.Model({nn_x, nn_tx}, {nn_z})
end

-- learning and evaluation.
local ars = require("ars")

--- Begin an epoch: snapshot the current weights into base_params and, unless
-- in playback mode, clear the rewards and ask the strategy for new trials.
local function prepare_epoch()
    trial_neg = false
    base_params = network:collect()
    if cfg.playback_mode then return end
    print('preparing epoch '..tostring(epoch_i)..'.')
    empty(trial_rewards)

    local precision
    if cfg.graycode then
        precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
        print(("chosen precision: %.2f"):format(precision))
    end

    -- `es` holds the strategy object; its configured name lives in cfg.es.
    if cfg.es == 'ars' then
        trial_params = es:ask(precision)
    else
        trial_params = es:ask()
    end
    trial_i = -1
end

--- Advance to the next trial and load its weights into the network.
-- Trial 0 evaluates the unperturbed base_params and is skipped unless
-- cfg.unperturbed_trial is set. With cfg.negate_trials, trials are consumed
-- in pairs: an odd trial is the positive half and the following even trial
-- the negative half (matching the pairing in do_reset/learn_from_epoch).
local function load_next_trial()
    trial_i = trial_i + 1
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_i = 1
    end
    if cfg.negate_trials then
        -- derive the sign from the index instead of toggling, so skipping
        -- trial 0 cannot shift the positive/negative pairing by one.
        trial_neg = trial_i % 2 == 0
    end
    if trial_i > 0 then
        --print('loading trial', trial_i)
        network:distribute(trial_params[trial_i])
    else
        --print("test trial")
        network:distribute(base_params)
    end
end

--- Finish an epoch: report the unperturbed trial's rank, feed the rewards to
-- the strategy, apply weight decay, log a CSV row, and (when the network is
-- enabled) load and save the updated weights.
local function learn_from_epoch()
    print()
    local current_cost = trial_rewards[0] -- may be nil!
    if cfg.unperturbed_trial then
        local nth_place = unperturbed_rank(trial_rewards, current_cost)
        -- a rank of 1 means our gradient is uninformative.
        print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
    end

    local delta_rewards = {} -- only used for logging.
    if cfg.negate_trials then
        for i = 1, #trial_rewards, 2 do
            local ind = floor(i / 2) + 1
            local pos = trial_rewards[i + 0]
            local neg = trial_rewards[i + 1]
            delta_rewards[ind] = abs(pos - neg)
        end
    end

    -- compare the configured strategy name, not the strategy object.
    if cfg.es == 'ars' then
        es:tell(trial_rewards, current_cost)
    else
        es:tell(trial_rewards)
    end

    local step_mean, step_dev = 0, 0
    --[[ TODO
    local step_mean, step_dev = calc_mean_dev(step)
    print("step mean:", step_mean)
    print("step stddev:", step_dev)
    --]]

    local momstep_mean, momstep_dev = 0, 0
    --[[ TODO
    if cfg.adamant then
        amsgrad(step)
        momstep_mean, momstep_dev = calc_mean_dev(step)
        print("amsgrad mean:", momstep_mean)
        print("amsgrad stddev:", momstep_dev)
    end
    --]]

    base_params = es:params()
    for i, v in ipairs(base_params) do
        base_params[i] = v * (1 - cfg.weight_decay)
    end
    es:params(base_params)

    local trial_mean, trial_std = calc_mean_dev(trial_rewards)
    local delta_mean, delta_std = calc_mean_dev(delta_rewards)
    local weight_mean, weight_std = calc_mean_dev(base_params)

    log_csv{
        epoch = epoch_i,
        trial_mean = trial_mean,
        trial_std = trial_std,
        delta_mean = delta_mean,
        delta_std = delta_std,
        step_std = step_dev,
        adam_std = momstep_dev,
        weight_mean = weight_mean,
        weight_std = weight_std,
        test_trial = current_cost or 0,
    }

    if cfg.enable_network then
        network:distribute(base_params)
        network:save(cfg.params_fn)
    else
        print("note: not updating weights in playable mode.")
    end
    print()
end

--- Write player 1's joypad with `button` pressed on odd frames only and all
-- other buttons released. Raises on an unknown button name.
local function joypad_mash(button)
    local jp_mash = {
        up = false, down = false, left = false, right = false,
        A = false, B = false, select = false, start = false,
    }
    assert(jp_mash[button] == false, "invalid button: "..tostring(button))
    jp_mash[button] = emu.framecount() % 2 == 1
    joypad.write(1, jp_mash)
end

--- Power-cycle the console and boot into `world`-`level`.
-- A world or level of 0 is replaced by random(1, 8) / random(1, 4). At frame
-- 32 the selection is written to RAM via game.W, at frame 42 the world-screen
-- timer is cleared, and start is mashed until frame 60.
local function loadlevel(world, level)
    -- TODO: move to smb.lua. rename to load_level.
    if world == 0 then world = random(1, 8) end
    if level == 0 then level = random(1, 4) end
    emu.poweron()
    while emu.framecount() < 60 do
        if emu.framecount() == 32 then
            local area = game.area_lut[world * 10 + level]
            game.W(0x75F, world - 1)
            game.W(0x75C, level - 1)
            game.W(0x760, area)
        end
        if emu.framecount() == 42 then
            game.W(0x7A0, 0) -- world screen timer (reduces startup time)
        end
        joypad_mash('start')
        emu.frameadvance()
    end
end

--- End the current trial and set up the next one.
-- Prints and records the trial's reward, runs learn_from_epoch/prepare_epoch
-- once the epoch's last trial is done, recomputes max_time, restores (or
-- first creates) the start savestate, and loads the next trial's weights.
local function do_reset()
    local state = game.get_state()
    -- be a little more descriptive.
    if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end

    if trial_i >= 0 then
        if trial_i == 0 then
            print('test trial reward:', reward, "("..state..")")
        elseif cfg.negate_trials then
            --local dir = trial_neg and "negative" or "positive"
            --print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
            if trial_neg then
                -- the positive half of this pair was recorded last.
                local pos = trial_rewards[#trial_rewards]
                local neg = reward
                local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
                print(fmt:format(floor(trial_i / 2), pos, neg, last_trial_state, state))
            end
            last_trial_state = state
        else
            print('trial', trial_i, 'reward:', reward, "("..state..")")
        end
    else
        print("reward:", reward, "("..state..")")
    end

    if trial_i >= 0 then
        if trial_i == 0 or not cfg.negate_trials then
            trial_rewards[trial_i] = reward
        else
            trial_rewards[#trial_rewards + 1] = reward
        end
    end

    -- the epoch ends after its last trial; with negated trials that is the
    -- negative half of the final pair. (epoch_i == 0 must short-circuit:
    -- trial_params is still nil before the first prepare_epoch.)
    if epoch_i == 0 or (trial_i == #trial_params
                        and (trial_neg or not cfg.negate_trials)) then
        if epoch_i > 0 then learn_from_epoch() end
        if not cfg.playback_mode then epoch_i = epoch_i + 1 end
        prepare_epoch()
        if any_random then
            loadlevel(cfg.starting_world, cfg.starting_level)
            state_saved = false
        end
    end

    max_time = min(6 * sqrt(480 / #trial_params * (epoch_i - 1)) + 60, cfg.cap_time)
    max_time = ceil(max_time)

    -- TODO: game.reset(cfg.starting_lives, cfg.start_big)
    if game.get_state() == 'loading' then game.advance() end -- kind of a hack.
    reward = 0
    powerup_old = game.R(0x754)
    status_old = game.R(0x756)
    coins_old = game.R(0x7ED) * 10 + game.R(0x7EE)
    score_old = game.get_score()

    -- set number of lives. (mario gets n+1 chances)
    game.W(0x75A, cfg.starting_lives)

    if cfg.start_big then
        -- make mario "super".
        game.W(0x754, 0)
        game.W(0x756, 1)
    end
    -- end of game.reset()

    if state_saved then
        savestate.load(startsave)
    else
        savestate.save(startsave)
    end
    state_saved = true

    jp = nil
    screen_scroll_delta = 0
    trial_frames = 0
    emu.frameadvance() -- prevents emulator from quirking up.
    load_next_trial()
    reset = false
end

--- One-time setup: build the network (optionally zero its weights), power on
-- the emulator, try to load saved parameters, and construct the strategy
-- named by cfg.es.
local function init()
    network = make_network(gcfg.input_size)
    network:reset()
    network:print()
    print("parameters:", network.n_param)

    if cfg.init_zeros then
        local W = network:collect()
        for i, w in ipairs(W) do W[i] = 0 end
        network:distribute(W)
    end

    emu.poweron()
    emu.unpause()
    if not cfg.playable_mode then emu.speedmode("turbo") end

    if not any_random then
        loadlevel(cfg.starting_world, cfg.starting_level)
    end

    -- errors from network.load are printed, not raised (presumably e.g. no
    -- params file exists yet on a first run).
    local res, err = pcall(network.load, network, cfg.params_fn)
    if res == false then print(err) end

    if cfg.es == 'ars' then
        es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
                     cfg.learning_rate, cfg.deviation, cfg.negate_trials)
    else
        error("Unknown evolution strategy specified: "..tostring(cfg.es))
    end
end

--- Request a reset at the next main-loop iteration (ignored in playback mode).
local function prepare_reset()
    if cfg.playback_mode then return end
    reset = true
end

--- Draw `frames` at (x, y) as a zero-padded "millions,thousands,units" counter.
local function draw_framecount(x, y, frames)
    local tf0 = frames % 1000
    local tf1 = (frames % 1000000 - tf0) / 1000
    -- whole millions, computed directly so the value is an exact integer.
    local tf2 = floor(frames / 1000000)
    gui.text(x, y, ("%03i,%03i,%03i"):format(tf2, tf1, tf0), '#FFFFFF', '#0000003F')
end

--- Per-frame logic: tick the timer, update the HUD and, unless `dummy` is
-- true, read the game state into network inputs, accumulate reward, detect
-- deaths/pauses, and choose the next joypad state `jp`.
-- @param dummy when true, keep holding the previous inputs (no AI this frame)
local function doit(dummy)
    local ingame_paused = game.get_state() == "paused"

    -- every few frames mario stands still, forcibly decrease the timer.
    -- this includes having the game paused.
    -- TODO: more robust. doesn't detect moonwalking against a wall.
    -- well, that shouldn't happen anymore now that i've disabled left+right.
    local timer = game.get_timer()
    if ingame_paused or (random() > 1 - cfg.timer_loser
                         and game.R(0x1D) == 0 and game.R(0x57) == 0) then
        timer = timer - 1
    end
    if not cfg.playback_mode then
        timer = clamp(timer, 0, max_time)
        if cfg.enable_network then game.set_timer(timer) end
    end

    draw_framecount(12, 215, total_frames)

    screen_scroll_delta = screen_scroll_delta + game.R(0x775)

    if dummy == true then
        -- don't invoke AI this frame. (keep holding the old inputs)
        gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
        return
    end

    empty(game.sprite_input)
    empty(game.tile_input)
    empty(game.extra_input)

    local controllable = game.R(0x757) == 0 and game.R(0x758) == 0

    local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = game.R(0x754)
    local status = game.R(0x756)
    game.mark_sprite(x + 8, y + 24, -powerup - 1)

    local vx, vy = game.S(0x57), game.S(0x9F)
    insert(game.extra_input, vx * 16)
    insert(game.extra_input, vy * 16)

    if cfg.time_inputs then
        for i = 2, 5 do
            local v = band(trial_frames, lshift(1, i)) == 0 and -181 or 181
            insert(game.extra_input, v)
        end
    end

    game.handle_enemies()
    game.handle_fireballs()
    --game.handle_blocks() -- blocks being hit. not interactable; we don't care!
    game.handle_hammers()
    game.handle_misc()
    game.handle_tiles()

    local coins = game.R(0x7ED) * 10 + game.R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)
    local flagpole_bonus = game.R(0xE) == 4 and cfg.frameskip or 0
    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
    local score_delta = game.get_score() - score_old
    if score_delta < 0 then score_delta = 0 end
    local reward_delta = screen_scroll_delta
                       + cfg.score_multiplier * (score_delta + flagpole_bonus)
    screen_scroll_delta = 0

    if cfg.decrement_reward and reward_delta == 0 then
        reward_delta = reward_delta - 1
    end

    if not ingame_paused and game.get_state() ~= 'win_walking' then
        -- note that we exclude points gained walking into the castle.
        -- this way, we avoid adding the timer-based fireworks to our reward,
        -- which are basically unwanted noise due to the way they trigger.
        if flagpole_bonus > 0 or controllable then
            reward = reward + reward_delta
        end
    end

    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')

    if game.get_state() == 'dead' and state_old ~= 'dead' then
        --print("dead. lives remaining:", game.R(0x75A, 0))
        if game.R(0x75A, 0) == 0 then prepare_reset() end
    end
    if game.get_state() == 'lose' then
        -- this shouldn't happen if we catch the deaths as above.
        print("ran out of lives.")
        if not cfg.playback_mode then prepare_reset() end
    end

    -- pausing costs a flat 402 points and ends the trial.
    -- (older alternative: lose a point for every frame paused.)
    --if ingame_paused then reward = reward - 1 end
    if ingame_paused then reward = reward - 402; prepare_reset() end

    -- if we've run out of time while the game is paused...
    -- that's cheating! unpause.
    force_start = ingame_paused and timer == 0

    local X = {}
    for i, v in ipairs(game.sprite_input) do insert(X, v / 256) end
    for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
    nn.reshape(X, 1, gcfg.input_size)
    nn.reshape(game.tile_input, 1, gcfg.tile_count)

    trial_frames = trial_frames + cfg.frameskip
    -- also runs while paused, so the force_start override below applies.
    if (cfg.enable_network and game.get_state() == 'playing') or ingame_paused then
        total_frames = total_frames + cfg.frameskip
        local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
        local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
        if cfg.det_epsilon and random() < eps then
            local i = floor(random() * #gcfg.jp_lut) + 1
            jp = copy(gcfg.jp_lut[i], jp)
        else
            local choose = cfg.deterministic and argmax or softchoice
            local ind = choose(unpack(outputs[nn_z]))
            jp = copy(gcfg.jp_lut[ind], jp)
        end

        if force_start then
            jp = {
                up = false, down = false, left = false, right = false,
                A = false, B = false, start = force_start_old, select = false,
            }
        end
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    force_start_old = force_start
    state_old = game.get_state()
    score_old = game.get_score()
end

init()
while true do
    gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

    while gcfg.bad_states[game.get_state()] do
        -- mash the start button until we have control.
        joypad_mash('start')
        prepare_reset()
        game.advance()
        gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')
        while game.get_state() == "loading" do game.advance() end -- kind of a hack.
        state_old = game.get_state()
    end

    if reset then
        do_reset()
        lagless_count = 0
    end

    if not cfg.enable_network then
        -- infinite time cheat. super handy for testing.
        if game.R(0xE) == 8 then
            game.set_timer(667)
            poketime = true
        elseif poketime then
            poketime = false
            game.set_timer(1)
        end
        -- infinite lives.
        game.W(0x75A, 1)
    end

    local doot = jp == nil or lagless_count % cfg.frameskip == 0
    doit(not doot)
    -- jp might still be nil if we're not ingame or we're not playing.
    if jp ~= nil then joypad.write(1, jp) end
    game.advance() -- remember, this skips lag frames.
    lagless_count = lagless_count + 1
end