preset = arg or "" -- must be done before requiring strict!
local globalize = require("strict")

-- configuration.

local cfg = require("config")
local gcfg = require("gameconfig")

-- state.

local params_fn
local std_fn

local epoch_i = 0
local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true
local trial_params --= {}
local trial_rewards = {}
local trials_remaining = 0
local es -- evolution strategy.

local trial_frames = 0
local total_frames = 0
local lagless_count = 0
local decisions_made = 0

local force_start = false
local force_start_old = false

local startsave = savestate.create(1)

local any_random = cfg.starting_world == 0 or cfg.starting_level == 0

local poketime = false
local max_time

local jp
local screen_scroll_delta
local reward
local powerup_old
local status_old
local coins_old
local score_old

local state_saved = false
local reset = true

local state_old = ''
local last_trial_state

-- localize some stuff.

local abs = math.abs
local assert = assert
local ceil = math.ceil
local collectgarbage = collectgarbage
local concat = table.concat
local exp = math.exp
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local pairs = pairs
local pow = math.pow
local print = print
local random = math.random
local randomseed = math.randomseed
local remove = table.remove
local select = select
local sort = table.sort
local sqrt = math.sqrt
local unpack = table.unpack or unpack

local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift
local rshift = bit.rshift
local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror

local emu = emu
local gui = gui

local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
local lerp = util.lerp
local softchoice = util.softchoice
local unperturbed_rank = util.unperturbed_rank
local exists = util.exists

local game = require("smb")
game.overlay = cfg.enable_overlay

-- utilities.

-- column order of the CSV log written by log_csv (key -> column index).
local log_map = {
    epoch = 1,
    trial_mean = 2,
    trial_std = 3,
    delta_mean = 4,
    delta_std = 5,
    step_std = 6,
    weight_mean = 7, -- TODO: rename to param_mean.
    weight_std = 8, -- TODO: rename to param_std.
    test_trial = 9,
    decisions = 10,
}

--- Append one comma-separated row to cfg.log_fn (no-op when it is nil).
-- Every key of `t` must appear in log_map and every log_map key must be
-- present in `t`; otherwise an error is raised. Values are written with
-- tostring, in log_map column order.
-- The row is validated BEFORE the file is opened, so a bad call can no
-- longer leave the file handle open.
local function log_csv(t)
    if cfg.log_fn == nil then return end

    local values = {}
    for k, v in pairs(t) do
        local i = log_map[k]
        if i == nil then error("Unexpected log key "..tostring(k)) end
        values[i] = tostring(v)
    end

    for k, i in pairs(log_map) do
        if values[i] == nil then error("Missing log key "..tostring(k)) end
    end

    local f, err = open(cfg.log_fn, 'a')
    if f == nil then
        error("Failed to open log file "..cfg.log_fn..": "..tostring(err))
    end
    f:write(concat(values, ","), '\n')
    f:close()
end

-- network parameters.

package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")

local network
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z

--- Build the policy network.
-- Two inputs: nn_x (a flat vector of `input_size` numbers) and nn_tx
-- (gcfg.tile_count tile ids, embedded into 2 dims each). They are merged,
-- optionally passed through a hidden Dense layer, and mapped to a softmax
-- over #gcfg.jp_lut joypad choices.
-- @param input_size length of the non-tile input vector.
-- @return an nn.Model with inputs {nn_x, nn_tx} and output {nn_z}.
local function make_network(input_size)
    nn_x = nn.Input({input_size})
    nn_tx = nn.Input({gcfg.tile_count})

    nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
    nn_tz = nn_ty

    if cfg.reduce_tiles then
        nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2})
        nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true))
        nn_tz = nn_tz:feed(nn.Relu())
        -- note: due to a quirk in Merge, we don't need to flatten nn_tz.
    end

    nn_y = nn.Merge()
    nn_x:feed(nn_y)
    nn_tz:feed(nn_y)

    if cfg.hidden then
        nn_y = nn_y:feed(nn.Dense(cfg.hidden_size, true))
        if cfg.deterministic then
            nn_y = nn_y:feed(nn.Relu())
        else
            nn_y = nn_y:feed(nn.Gelu())
        end
        if cfg.layernorm then
            nn_y = nn_y:feed(nn.LayerNorm())
        end
    end

    nn_z = nn_y
    -- fixed: `true, cfg.bias_out` previously sat outside Dense's parens and
    -- were passed to feed() instead, so cfg.bias_out was silently ignored.
    -- NOTE(review): this assumes Dense's 3rd argument controls the bias; when
    -- cfg.bias_out differs from Dense's default, n_param (and therefore the
    -- default params_fn) may change.
    nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut, true, cfg.bias_out))
    nn_z = nn_z:feed(nn.Softmax())
    return nn.Model({nn_x, nn_tx}, {nn_z})
end

-- learning and evaluation.
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")

--- Start an epoch: snapshot the network's parameters into base_params and,
-- unless in playback mode, clear trial_rewards and ask the evolution
-- strategy for this epoch's trial parameter vectors (trial_params).
-- Leaves trial_i at -1 so the next load_next_trial() starts at trial 0.
local function prepare_epoch()
    trial_neg = false

    base_params = network:collect()
    if cfg.playback_mode then return end

    print('preparing epoch '..tostring(epoch_i)..'...')
    empty(trial_rewards)

    if cfg.es == 'xnes' then
        print("sigma:", es.sigma)
    elseif cfg.es == 'snes' then
        local sigma_mean, sigma_dev = calc_mean_dev(es.std)
        --print("sigma:", sigma_mean, sigma_dev)
        print("sigma 50%:", sigma_mean)
        print("sigma 95%:", sigma_mean + sigma_dev * 1.64485)
    end

    local precision
    if cfg.graycode then
        precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
        print(("chosen precision: %.2f"):format(precision))
    end

    -- only the first return value (the trial parameters) is used.
    if cfg.es == 'ars' then
        trial_params = es:ask(precision)
    elseif cfg.es == 'snes' then
        trial_params = es:ask_mix()
    else
        trial_params = es:ask()
    end

    trial_i = -1
end

--- Advance trial_i and load that trial's parameters into the network.
-- Trial 0 uses base_params and is skipped (jumping to trial 1) unless
-- cfg.unperturbed_trial is set. With cfg.negate_trials, trial_neg flips
-- on every call; otherwise it is always true.
local function load_next_trial()
    if cfg.negate_trials then
        trial_neg = not trial_neg
    else
        trial_neg = true
    end

    trial_i = trial_i + 1
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_neg = false
        trial_i = 1
    end

    if trial_i > 0 then
        --print('loading trial', trial_i)
        network:distribute(trial_params[trial_i])
    else
        --print("test trial")
        network:distribute(base_params)
    end
end

--- Finish an epoch: feed trial_rewards to the evolution strategy, apply
-- decay, log statistics via log_csv, and (when the network is enabled)
-- save the new parameters to params_fn (plus SNES sigmas to std_fn).
local function learn_from_epoch()
    print()

    local current_cost = trial_rewards[0] -- may be nil!

    if cfg.unperturbed_trial then
        local nth_place = unperturbed_rank(trial_rewards, current_cost)

        -- a rank of 1 means our gradient is uninformative.
        print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
    end

    local delta_rewards = {} -- only used for logging.
    if cfg.negate_trials then
        -- rewards are stored as consecutive (positive, negative) pairs.
        for i = 1, #trial_rewards, 2 do
            local ind = floor(i / 2) + 1
            local pos = trial_rewards[i + 0]
            local neg = trial_rewards[i + 1]
            delta_rewards[ind] = abs(pos - neg)
        end
    end

    local step
    if cfg.es == 'ars' and cfg.ars_lips then
        step = es:tell(trial_rewards, current_cost)
    else
        step = es:tell(trial_rewards)
    end

    local step_mean, step_dev = calc_mean_dev(step)
    print("step mean:", step_mean)
    print("step stddev:", step_dev)

    es:decay(cfg.param_decay, cfg.sigma_decay)

    base_params = es:params()

    local trial_mean, trial_std = calc_mean_dev(trial_rewards)
    local delta_mean, delta_std = calc_mean_dev(delta_rewards)
    local param_mean, param_std = calc_mean_dev(base_params)

    log_csv{
        epoch = epoch_i,
        trial_mean = trial_mean,
        trial_std = trial_std,
        delta_mean = delta_mean,
        delta_std = delta_std,
        step_std = step_dev,
        weight_mean = param_mean,
        weight_std = param_std,
        test_trial = current_cost or 0,
        decisions = decisions_made,
    }

    if cfg.enable_network then
        network:distribute(base_params)
        network:save(params_fn)

        if cfg.es == 'snes' then
            local f = assert(open(std_fn, "w"))
            for _, v in ipairs(es.std) do
                f:write(("%f\n"):format(v))
            end
            f:close()
        end
    else
        print("note: not updating params in playable mode.")
    end

    print()
end

--- Write joypad 1 with only `button` held, toggled on odd frames so the
-- press registers repeatedly. Raises on an unknown button name.
local function joypad_mash(button)
    local jp_mash = {
        up = false,
        down = false,
        left = false,
        right = false,
        A = false,
        B = false,
        select = false,
        start = false,
    }
    assert(jp_mash[button] == false, "invalid button: "..tostring(button))
    jp_mash[button] = emu.framecount() % 2 == 1
    joypad.write(1, jp_mash)
end

--- Power-cycle the console and boot into the given world/level
-- (0 for either picks one at random) by poking RAM during startup.
local function loadlevel(world, level)
    -- TODO: move to smb.lua. rename to load_level.
    if world == 0 then world = random(1, 8) end
    if level == 0 then level = random(1, 4) end
    emu.poweron()
    while emu.framecount() < 60 do
        if emu.framecount() == 32 then
            local area = game.area_lut[world * 10 + level]
            game.W(0x75F, world - 1)
            game.W(0x75C, level - 1)
            game.W(0x760, area)
        end
        if emu.framecount() == 42 then
            game.W(0x7A0, 0) -- world screen timer (reduces startup time)
        end
        joypad_mash('start')
        emu.frameadvance()
    end
end

--- End the current trial and set up the next one: record/print the
-- finished trial's reward, roll over to a new epoch when all trials are
-- done, recompute max_time, reset per-trial game state (lives, powerup,
-- savestate) and load the next trial's parameters.
local function do_reset()
    local state = game.get_state()
    -- be a little more descriptive.
    if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end

    if trial_i >= 0 then
        if trial_i == 0 then
            print('test trial reward:', reward, "("..state..")")
        elseif cfg.negate_trials then
            --local dir = trial_neg and "negative" or "positive"
            --print('trial', trial_i, dir, 'reward:', reward, "("..state..")")

            if trial_neg then
                local pos = trial_rewards[#trial_rewards]
                local neg = reward
                local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
                print(fmt:format(floor(trial_i / 2),
                                 pos, neg, last_trial_state, state))
            end
            last_trial_state = state
        else
            print('trial', trial_i, 'reward:', reward, "("..state..")")
        end

        if trial_i == 0 or not cfg.negate_trials then
            trial_rewards[trial_i] = reward
        else
            trial_rewards[#trial_rewards + 1] = reward
        end
    end

    if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
        if epoch_i > 0 then learn_from_epoch() end
        if not cfg.playback_mode then epoch_i = epoch_i + 1 end
        prepare_epoch()
        collectgarbage()
        if any_random then
            loadlevel(cfg.starting_world, cfg.starting_level)
            state_saved = false
        end
    end

    -- max(0, ...) avoids sqrt of a negative (NaN) when epoch_i stays 0,
    -- which happens in playback mode; training runs have epoch_i >= 1 here.
    max_time = 6 * sqrt(10 * max(0, epoch_i - 1)) + 60
    max_time = clamp(max_time, cfg.min_time, cfg.max_time)
    max_time = ceil(max_time)

    -- TODO: game.reset(cfg.starting_lives, cfg.start_big)
    if game.get_state() == 'loading' then game.advance() end -- kind of a hack.
    reward = 0
    powerup_old = game.R(0x754)
    status_old = game.R(0x756)
    coins_old = game.R(0x7ED) * 10 + game.R(0x7EE)
    score_old = game.get_score()

    -- set number of lives. (mario gets n+1 chances)
    game.W(0x75A, cfg.starting_lives)

    if cfg.start_big then
        -- make mario "super".
        game.W(0x754, 0)
        game.W(0x756, 1)
    end
    -- end of game.reset()

    -- the first reset captures the starting state; later resets restore it.
    if state_saved then
        savestate.load(startsave)
    else
        savestate.save(startsave)
    end
    state_saved = true

    jp = nil
    screen_scroll_delta = 0
    trial_frames = 0
    emu.frameadvance() -- prevents emulator from quirking up.

    load_next_trial()
    reset = false
end

--- One-time setup: build the network (optionally zeroed), boot the
-- emulator, pick params_fn/std_fn, load saved parameters and SNES sigmas
-- if present, and construct the configured evolution strategy `es`.
local function init()
    network = make_network(gcfg.input_size)
    network:reset()
    network:print()
    print("parameters:", network.n_param)

    if cfg.init_zeros then
        local W = network:collect()
        for i = 1, #W do W[i] = 0 end
        network:distribute(W)
    end

    emu.poweron()
    emu.unpause()
    local playing = cfg.playable_mode or cfg.playback_mode
    if not playing then emu.speedmode("turbo") end

    if not any_random then
        loadlevel(cfg.starting_world, cfg.starting_level)
    end

    params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
    -- strip a trailing literal ".txt" ('.' must be escaped in Lua patterns;
    -- the old unanchored ".txt" could match e.g. "atxt" mid-path).
    -- the extra parens drop gsub's second return value (the match count).
    std_fn = (params_fn:gsub("%.txt$", ""))..".sigma.txt"

    if exists(params_fn) then
        network:load(params_fn)
    end

    if cfg.es == 'xnes' then
        -- if you get an out of memory error, you can't use xNES. sorry!
        -- maybe there'll be a patch for FCEUX in the future.
        es = xnes.Xnes(network.n_param, cfg.epoch_trials,
                       cfg.base_rate, cfg.deviation, cfg.negate_trials)
    elseif cfg.es == 'snes' then
        es = snes.Snes(network.n_param, cfg.epoch_trials,
                       cfg.base_rate, cfg.deviation, cfg.negate_trials)
        -- TODO: clean this up into an interface:
        es.min_refresh = cfg.min_refresh

        if exists(std_fn) then
            local f = assert(open(std_fn, "r"))
            for i = 1, network.n_param do
                es.std[i] = assert(tonumber(assert(f:read())))
            end
            f:close()
        end
    elseif cfg.es == 'ars' then
        es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
                     cfg.base_rate, cfg.deviation, cfg.negate_trials,
                     cfg.momentum)
    else
        -- fixed: used `+` (arithmetic) instead of `..` (concatenation).
        error("Unknown evolution strategy specified: "..tostring(cfg.es))
    end

    es.param_rate = cfg.param_rate
    es.sigma_rate = cfg.sigma_rate
    es.covar_rate = cfg.covar_rate
    es.rate_init = cfg.sigma_rate -- just for SNES?

    es:params(network:collect())
end

--- Request a reset at the top of the next main-loop iteration
-- (ignored in playback mode).
local function prepare_reset()
    if cfg.playback_mode then return end
    reset = true
end

--- Draw `frames` on screen at (x, y) as three comma-separated groups of
-- three digits.
local function draw_framecount(x, y, frames)
    local tf0 = frames % 1000
    local tf1 = (frames % 1000000 - tf0) / 1000
    local tf2 = (frames - tf0 - tf1) / 1000000
    gui.text(x, y, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
end

--- Per-step update. Adjusts the in-game timer and accumulates reward; when
-- `dummy` is true it stops there (the previous inputs keep being held).
-- Otherwise it rebuilds the network inputs, detects death/game-over/pause
-- (requesting a reset), and — when the network drives the game — picks
-- the next joypad state `jp`.
local function doit(dummy)
    local ingame_paused = game.get_state() == "paused"

    -- every few frames mario stands still, forcibly decrease the timer.
    -- this includes having the game paused.
    -- TODO: more robust. doesn't detect moonwalking against a wall.
    -- well, that shouldn't happen anymore now that i've disabled left+right.
    local timer = game.get_timer()
    if ingame_paused or (random() > 1 - cfg.timer_loser
                         and game.R(0x1D) == 0 and game.R(0x57) == 0) then
        timer = timer - 1
    end
    if not cfg.playback_mode then
        timer = clamp(timer, 0, max_time)
        if cfg.enable_network then
            game.set_timer(timer)
        end
    end

    draw_framecount(12, 215, decisions_made)

    screen_scroll_delta = screen_scroll_delta + game.R(0x775)

    if dummy == true then
        -- don't invoke AI this frame. (keep holding the old inputs)
        gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
        return
    end

    empty(game.sprite_input)
    empty(game.tile_input)
    empty(game.extra_input)

    local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
    local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = game.R(0x754)
    local status = game.R(0x756)
    game.mark_sprite(x + 8, y + 24, -powerup - 1)

    -- TODO: this will have to do until sprite type embed is added:
    insert(game.extra_input, (status - 1) * 256)

    local vx, vy = game.S(0x57), game.S(0x9F)
    insert(game.extra_input, vx * 16)
    insert(game.extra_input, vy * 16)

    if cfg.time_inputs then
        -- four +/-181 inputs taken from bits 2..5 of trial_frames.
        for i = 2, 5 do
            local v = band(trial_frames, lshift(1, i)) == 0 and -181 or 181
            insert(game.extra_input, v)
        end
    end

    game.handle_enemies()
    game.handle_fireballs()
    --game.handle_blocks() -- blocks being hit. not interactable; we don't care!
    game.handle_hammers()
    game.handle_misc()
    game.handle_tiles()

    -- coins_delta, powerup_delta and status_delta are not used by the
    -- current reward formula (see the commented-out one below).
    local coins = game.R(0x7ED) * 10 + game.R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)
    local flagpole_bonus = game.R(0xE) == 4 and cfg.frameskip or 0
    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
    local score_delta = game.get_score() - score_old
    if score_delta < 0 then score_delta = 0 end
    local reward_delta = screen_scroll_delta
                       + cfg.score_multiplier * (score_delta + flagpole_bonus)
    screen_scroll_delta = 0

    if cfg.decrement_reward and reward_delta == 0 then
        reward_delta = reward_delta - 1
    end

    if not ingame_paused and game.get_state() ~= 'win_walking' then
        -- note that we exclude points gained walking into the castle.
        -- this way, we avoid adding the timer-based fireworks to our reward,
        -- which are basically unwanted noise due to the way they trigger.
        if flagpole_bonus > 0 or controllable then
            reward = reward + reward_delta
        end
    end

    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')

    if game.get_state() == 'dead' and state_old ~= 'dead' then
        --print("dead. lives remaining:", game.R(0x75A, 0))
        if game.R(0x75A, 0) == 0 then prepare_reset() end
    end
    if game.get_state() == 'lose' then
        -- this shouldn't happen if we catch the deaths as above.
        print("ran out of lives.")
        prepare_reset()
    end

    -- lose a point for every frame paused.
    --if ingame_paused then reward = reward - 1 end
    -- (currently: a flat 402-point penalty and an immediate reset instead.)
    if ingame_paused then
        reward = reward - 402
        prepare_reset()
    end

    -- if we've run out of time while the game is paused...
    -- that's cheating! unpause.
    force_start = ingame_paused and timer == 0

    local X = {}
    for _, v in ipairs(game.sprite_input) do insert(X, v / 256) end
    for _, v in ipairs(game.extra_input) do insert(X, v / 256) end
    nn.reshape(X, 1, gcfg.input_size)
    nn.reshape(game.tile_input, 1, gcfg.tile_count)

    trial_frames = trial_frames + cfg.frameskip
    -- NOTE(review): parses as (enable_network and playing) or paused, so the
    -- network also picks inputs while paused even when enable_network is
    -- false — confirm this is intended.
    if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
        total_frames = total_frames + cfg.frameskip

        local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})

        local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
        if cfg.det_epsilon and random() < eps then
            local i = floor(random() * #gcfg.jp_lut) + 1
            jp = copy(gcfg.jp_lut[i], jp)
        else
            local choose = cfg.deterministic and argmax or softchoice
            local ind = choose(unpack(outputs[nn_z]))
            jp = copy(gcfg.jp_lut[ind], jp)
        end

        if force_start then
            jp = {
                up = false,
                down = false,
                left = false,
                right = false,
                A = false,
                B = false,
                start = force_start_old,
                select = false,
            }
        else
            decisions_made = decisions_made + 1
        end
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    force_start_old = force_start
    state_old = game.get_state()
    score_old = game.get_score()
end

init()
while true do
    gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

    while gcfg.bad_states[game.get_state()] do
        -- mash the start button until we have control.
        joypad_mash('start')
        prepare_reset()

        game.advance()
        gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

        while game.get_state() == "loading" do game.advance() end -- kind of a hack.
        state_old = game.get_state()
    end

    if reset then
        do_reset()
        lagless_count = 0
    end

    if not cfg.enable_network then
        -- infinite time cheat. super handy for testing.
        if game.R(0xE) == 8 then
            game.set_timer(667)
            poketime = true
        elseif poketime then
            poketime = false
            game.set_timer(1)
        end

        -- infinite lives.
        game.W(0x75A, 1)
    end

    -- only consult the network every cfg.frameskip lag-free frames
    -- (or whenever there is no input yet).
    local doot = jp == nil or lagless_count % cfg.frameskip == 0
    doit(not doot)

    -- jp might still be nil if we're not ingame or we're not playing.
    if jp ~= nil then joypad.write(1, jp) end

    game.advance() -- remember, this skips lag frames.
    lagless_count = lagless_count + 1
end