local globalize = require("strict")

-- configuration.

--randomseed(11)

local cfg = require("config")
local gcfg = require("gameconfig")

-- state.

local epoch_i = 0
local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true
local trial_noise = {}
local trial_rewards = {}
local trials_remaining = 0
local trial_frames = 0
local total_frames = 0
local force_start = false
local force_start_old = false
local startsave = savestate.create(1)
local poketime = false
local max_time

-- network inputs; refilled on every frame the network is consulted.
local sprite_input = {}
local tile_input = {}
local extra_input = {}

-- the joypad table most recently chosen for player 1, or nil if none yet.
local jp

local screen_scroll_delta
local reward
--local all_rewards = {}

-- values from the previous evaluated frame, used to compute deltas.
local powerup_old
local status_old
local coins_old
local score_old

local once = false
local reset = true

local state_old = ''
local last_trial_state

-- localize some stuff.

local print = print
local ipairs = ipairs
local pairs = pairs
local select = select

local abs = math.abs
local floor = math.floor
local ceil = math.ceil
local min = math.min
local max = math.max
local exp = math.exp
local log = math.log
local sqrt = math.sqrt
local random = math.random
local randomseed = math.randomseed

local insert = table.insert
local remove = table.remove
local unpack = table.unpack or unpack
local sort = table.sort

local R = memory.readbyteunsigned
-- S is used for velocities and the vertical high byte, which need negative
-- values. memory.readbyte is unsigned, so the signed variant is required.
local S = memory.readbytesigned
local W = memory.writebyte

local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift
local rshift = bit.rshift
local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror

-- utilities.

-- true when exactly one of a, b is truthy.
local function boolean_xor(a, b)
    if a and b then return false end
    if not a and not b then return false end
    return true
end

local _invlog2 = 1 / log(2)

-- base-2 logarithm.
local function log2(x)
    return log(x) * _invlog2
end

-- restricts x to the range [l, u].
local function clamp(x, l, u)
    return min(max(x, l), u)
end

-- linear interpolation from a to b; t is clamped to [0, 1].
local function lerp(a, b, t)
    return a + (b - a) * clamp(t, 0, 1)
end

-- returns the 1-based position of the largest argument (the first one wins
-- ties), or 0 when no argument compares greater than negative infinity.
local function argmax(...)
    local max_i = 0
    local max_v = -math.huge
    for i = 1, select("#", ...) do
        local v = select(i, ...)
        if v > max_v then
            max_i = i
            max_v = v
        end
    end
    return max_i
end

-- samples a 1-based position, treating the arguments as probabilities.
-- returns nothing when called without arguments.
local function softchoice(...)
    local n = select("#", ...)
    local t = random()
    local psum = 0
    for i = 1, n do
        local p = select(i, ...)
        psum = psum + p
        if t < psum then return i end
    end
    -- rounding can leave the total a hair under the drawn value;
    -- pick the last entry rather than returning nil to the caller.
    if n > 0 then return n end
end

-- true when the first of two entries is the larger one.
local function argmax2(t)
    return t[1] > t[2]
end

-- true with probability t[1].
local function rchoice2(t)
    return t[1] > random()
end

-- a fair coin flip.
local function rbool()
    return 0.5 >= random()
end

-- removes every key from t in place and returns t.
local function empty(t)
    for k, _ in pairs(t) do t[k] = nil end
    return t
end

-- returns the mean and the population standard deviation of sequence x.
local function calc_mean_dev(x)
    local n = #x

    local mean = 0
    for i, v in ipairs(x) do
        mean = mean + v / n
    end

    local dev = 0
    for i, v in ipairs(x) do
        local delta = v - mean
        dev = dev + delta * delta / n
    end

    return mean, sqrt(dev)
end

-- standardizes x using its own statistics. writes into out (default: x).
local function normalize(x, out)
    out = out or x
    local mean, dev = calc_mean_dev(x)
    if dev <= 0 then dev = 1 end -- avoid dividing by zero on constant input.
    for i, v in ipairs(x) do out[i] = (v - mean) / dev end
    return out
end

-- standardizes x using the statistics of s. writes into out (default: x).
local function normalize_wrt(x, s, out)
    out = out or x
    local mean, dev = calc_mean_dev(s)
    if dev <= 0 then dev = 1 end -- avoid dividing by zero on constant input.
    for i, v in ipairs(x) do out[i] = (v - mean) / dev end
    return out
end

-- network parameters.

package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")

local network
local nn_x, nn_tx, nn_ty, nn_y, nn_z

-- builds the model: the plain inputs and the embedded tile inputs are merged,
-- passed through one hidden layer, and end in a softmax with one output per
-- entry of gcfg.jp_lut.
local function make_network(input_size)
    nn_x = nn.Input({input_size})
    nn_tx = nn.Input({gcfg.tile_count})
    nn_ty = nn_tx:feed(nn.Embed(256, 2))
    nn_y = nn.Merge()
    nn_x:feed(nn_y)
    nn_ty:feed(nn_y)

    nn_y = nn_y:feed(nn.Dense(128))
    if cfg.deterministic then
        nn_y = nn_y:feed(nn.Relu())
    else
        nn_y = nn_y:feed(nn.Gelu())
    end

    nn_z = nn_y
    nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
    nn_z = nn_z:feed(nn.Softmax())

    return nn.Model({nn_x, nn_tx}, {nn_z})
end

-- and here we go with the game stuff.
-- disassembly used for reference:
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM

-- reads the three timer digits and returns them as one number.
local function get_timer()
    return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
end

-- reads five score digits and returns them as one number.
local function get_score()
    return R(0x7DE) * 10000 +
           R(0x7DF) * 1000 +
           R(0x7E0) * 100 +
           R(0x7E1) * 10 +
           R(0x7E2)
end

-- writes time back as three separate decimal digits.
local function set_timer(time)
    W(0x7F8, floor(time / 100))
    W(0x7F9, floor((time / 10) % 10))
    W(0x7FA, floor(time % 10))
end

-- appends an (x, y, type) triple to sprite_input; positions outside the
-- screen bounds are recorded as three zeros. draws a labelled box when the
-- overlay is enabled and the type is nonzero.
local function mark_sprite(x, y, t)
    if x < 0 or x >= 256 or y < 0 or y > 224 then
        sprite_input[#sprite_input+1] = 0
        sprite_input[#sprite_input+1] = 0
        sprite_input[#sprite_input+1] = 0
    else
        sprite_input[#sprite_input+1] = x
        sprite_input[#sprite_input+1] = y
        sprite_input[#sprite_input+1] = t
    end

    if t == 0 then return end
    if cfg.enable_overlay then
        gui.box(x-4, y-4, x+4, y+4)
        --gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
        gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
        --gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F')
    end
end

-- appends tile id t to tile_input. draws a labelled box when the overlay is
-- enabled and the id is nonzero.
local function mark_tile(x, y, t)
    tile_input[#tile_input+1] = t

    if t == 0 then return end
    if cfg.enable_overlay then
        gui.box(x-8, y-8, x+8, y+8)
        gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
    end
end

-- returns the x, y of object slot i relative to the screen.
-- x_addr and y_addr are the base addresses of the position arrays.
-- pageloc_addr and hipos_addr are optional; when given, x is corrected by
-- the page difference and y by the (signed) vertical high byte.
-- NOTE(review): 0x71A and 0x71C look like the page and pixel position of the
-- screen's left edge, going by the disassembly linked above -- confirm.
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
    local spl_l = R(0x71A)
    local sx_l = R(0x71C)

    local x = R(x_addr + i)
    local y = R(y_addr + i)

    local sx, sy = x, y
    if pageloc_addr ~= nil then
        local page = R(pageloc_addr + i)
        sx = sx - sx_l - (spl_l - page) * 256
    else
        sx = sx - sx_l
    end
    if hipos_addr ~= nil then
        local hipos = S(hipos_addr + i)
        sy = sy + (hipos - 1) * 256
    end

    return sx, sy
end

-- returns the lowest bit of 0x776: nonzero while the game is paused.
local function paused()
    return band(R(0x776), 1)
end

-- classifies the current game state as a short string.
-- the checks are ordered: the first one that matches wins.
local function get_state()
    if R(0xE) == 0xFF then return 'power' end
    if R(0x774) > 0 then return 'lagging' end
    if R(0x7A2) > 0 then return 'waiting_demo' end
    if R(0x717) > 0 then return 'playing_demo' end
    -- if R(0x770) == 0xFF then return 'power' end
    if paused() ~= 0 then return 'paused' end
    if R(0xE) == 0 then return 'world_screen' end
    -- if R(0x712) == 1 then return 'deadmusic' end
    if R(0x7CA) == 0x94 then return 'dead' end
    if R(0xE) == 4 then return 'win_flagpole' end
    if R(0xE) == 5 then return 'win_walking' end
    if R(0xE) == 6 then return 'lose' end
    -- if R(0x770) == 0 then return 'not_playing' end
    if R(0x770) == 2 then return 'win_castle' end
    if R(0x772) == 2 then return 'no_control' end
    if R(0x772) == 3 then return 'playing' end
    if R(0x770) == 1 then return 'loading' end
    if R(0x770) == 3 then return 'lose' end
    return 'unknown'
end

-- advances one frame, then keeps advancing past any lag frames.
local function advance()
    emu.frameadvance()
    while emu.lagged() do emu.frameadvance() end -- skip lag frames.
    while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
end

-- marks the six enemy slots, adjusting the marked position by object type.
local function handle_enemies()
    -- enemies, flagpole
    for i = 0, 5 do
        local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
        x, y = x + 8, y + 16
        local tid = R(0x16 + i)
        local flags = R(0xF + i)
        --local offscr = R(0x3D8 + i)
        local invisible = tid < 0x10 and flags == 0

        if tid == 0x30 then y = y - 8 end -- flagpole flag
        if tid == 0x31 then y = y - 8 end -- castle flag
        if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks
        if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms
        if tid == 0x2D then x, y = x, y end -- bowser (TODO: determine head or body)
        if tid == 0x15 then x, y = x, y - 12 end -- bowser fire
        if tid == 0x32 then x, y = x, y - 8 end -- spring
        -- tid == 0x35 -- toad

        if tid == 0x1D or tid == 0x1B then -- rotating fire bars
            x, y = x - 4, y - 12
            -- this is a mess... gotta find out its rotation and then project.
            -- TODO: handle long fire bars too
            local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
            -- debug text is part of the overlay, so it honours the same flag
            -- as the boxes drawn by mark_sprite.
            if cfg.enable_overlay then
                gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
            end
            local x_off = gcfg.rotation_offsets[rot*2+1]
            local y_off = gcfg.rotation_offsets[rot*2+2]
            x, y = x + x_off, y + y_off
        end

        if invisible then
            mark_sprite(0, 0, 0)
        else
            mark_sprite(x, y, tid + 1)
        end
    end
end

-- marks the two fireball slots; a slot whose state is zero is marked empty.
local function handle_fireballs()
    for i = 0, 1 do
        local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC)
        x, y = x + 4, y + 4
        local state = R(0x24 + i)
        local invisible = state == 0
        if invisible then
            mark_sprite(0, 0, 0)
        else
            mark_sprite(x, y, 257)
        end
    end
end

-- marks the four block slots; a slot whose state is zero is marked empty.
local function handle_blocks()
    for i = 0, 3 do
        local x, y = getxy(i, 0x8F, 0xD7, 0x76, 0xBE)
        x, y = x + 8, y + 8
        local state = R(0x26 + i)
        local invisible = state == 0
        if invisible then
            mark_sprite(0, 0, 0)
        else
            mark_sprite(x, y, 258)
        end
    end
end

-- marks the nine misc-object slots whose state is 0x30 or above.
local function handle_hammers()
    -- hammers, coins, score bonus text...
    for i = 0, 8 do
        local x, y = getxy(i, 0x93, 0xDB, 0x7A, 0xC2)
        x, y = x + 8, y + 8
        local state = R(0x2A + i)
        -- skip coin effect states. not interactable; we don't care!
        if state ~= 0 and state >= 0x30 then
            mark_sprite(x, y, state + 1)
        else
            mark_sprite(0, 0, 0)
        end
    end
end

-- marks the single object slot read through 0x33, if its state is nonzero.
local function handle_misc()
    for i = 0, 0 do
        local x, y = getxy(i, 0x9C, 0xE4, 0x83, 0xCB)
        x, y = x + 8, y + 8
        local state = R(0x33 + i)
        if state ~= 0 then
            mark_sprite(x, y, state + 1)
        else
            mark_sprite(0, 0, 0)
        end
    end
end

-- records the sub-tile scroll offset in extra_input, then marks a grid of
-- 17 columns by 13 rows of tiles starting at the current scroll position.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    extra_input[#extra_input+1] = tile_scroll_remainder
    for y = 0, 12 do
        for x = 0, 16 do
            -- columns 0-15 and 16-31 are read from two separate buffers.
            local col = (x + tile_scroll) % 32
            local t
            if col < 16 then
                t = R(0x500 + y * 16 + (col % 16))
            else
                t = R(0x5D0 + y * 16 + (col % 16))
            end
            local sx = x * 16 + 8 - tile_scroll_remainder
            local sy = y * 16 + 40
            mark_tile(sx, sy, t)
        end
    end
end

-- learning and evaluation.
-- snapshots the current weights and draws fresh noise for every trial.
local function prepare_epoch()
    print('preparing epoch '..tostring(epoch_i)..'.')

    base_params = network:collect()
    empty(trial_noise)
    empty(trial_rewards)

    -- TODO: save memory. generate noise as needed by saving the seed
    -- (the os.time() as of here) and calling nn.normal() each trial.
    for i = 1, cfg.epoch_trials do
        local noise = nn.zeros(#base_params)
        -- NOTE: change in implementation: deviation is multiplied here
        -- and ONLY here now.
        for j = 1, #base_params do noise[j] = cfg.deviation * nn.normal() end
        trial_noise[i] = noise
    end
    trial_i = -1
end

-- loads the next half of a positive/negative pair into the network.
-- each noise vector is used twice: first added, then subtracted.
local function load_next_pair()
    trial_i = trial_i + 1
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_i = 1
        trial_neg = true
    end

    local params = nn.copy(base_params)

    if trial_i > 0 then
        if trial_neg then
            if not cfg.defer_prints then print('trial', trial_i, 'positive') end
            local noise = trial_noise[trial_i]
            for i, v in ipairs(base_params) do
                params[i] = v + noise[i]
            end
        else
            -- undo the increment above: this is the second half of the pair.
            trial_i = trial_i - 1
            if not cfg.defer_prints then print('trial', trial_i, 'negative') end
            local noise = trial_noise[trial_i]
            for i, v in ipairs(base_params) do
                params[i] = v - noise[i]
            end
        end

        trial_neg = not trial_neg
    else
        if not cfg.defer_prints then print("test trial") end
    end

    network:distribute(params)
end

-- loads the weights for the next trial into the network.
-- defers to load_next_pair when cfg.negate_trials is set.
local function load_next_trial()
    if cfg.negate_trials then return load_next_pair() end
    trial_i = trial_i + 1
    local params = nn.copy(base_params)
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_i = 1
    end
    if trial_i > 0 then
        print('loading trial', trial_i)
        local noise = trial_noise[trial_i]
        for i, v in ipairs(base_params) do
            params[i] = v + noise[i]
        end
    else
        print("test trial")
    end
    network:distribute(params)
end

-- converts raw rewards into rank-based utilities.
local function fitness_shaping(rewards)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local decreasing = nn.copy(rewards)
    sort(decreasing, function(a, b) return a > b end)

    local lamb = #rewards
    local l = log2(lamb / 2 + 1)

    -- compute each numerator once; it is needed for the sum and the result.
    local numers = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local r = log2(nn.indexof(decreasing, v))
        local numer = max(0, l - r)
        numers[i] = numer
        denom = denom + numer
    end

    local shaped_returns = {}
    for i, numer in ipairs(numers) do
        insert(shaped_returns, numer / denom + 1 / lamb)
    end
    return shaped_returns
end

-- returns 1 plus the number of rewards strictly above unperturbed_reward.
local function unperturbed_rank(rewards, unperturbed_reward)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local nth_place = 1
    for i, v in ipairs(rewards) do
        if v > unperturbed_reward then
            nth_place = nth_place + 1
        end
    end
    return nth_place
end

-- turns the rewards of the finished epoch into a weight update.
-- NOTE(review): the pair indexing below reads trial_rewards as consecutive
-- (positive, negative) pairs, which is only how they are stored when
-- cfg.negate_trials is set -- confirm before running without it.
local function learn_from_epoch()
    print()
    --print('rewards:', trial_rewards)

    --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end

    if cfg.unperturbed_trial then
        local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])

        -- a rank of 1 means our gradient is uninformative.
        print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
    end

    local step = nn.zeros(#base_params)

    -- new stuff

    local best_rewards
    if cfg.negate_trials then
        -- select one (the best) reward of each pos/neg pair.
        best_rewards = {}
        for i = 1, cfg.epoch_trials do
            local ind = (i - 1) * 2 + 1
            local pos = trial_rewards[ind + 0]
            local neg = trial_rewards[ind + 1]
            best_rewards[i] = max(pos, neg)
        end
    else
        best_rewards = nn.copy(trial_rewards)
    end

    -- rank the trials by reward and keep only the top few.
    local indices = {}
    for i = 1, #best_rewards do indices[i] = i end
    sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)

    --print("indices:", indices)
    for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
    print("best trials:", indices)

    -- rewards of trials outside the top set stay at zero.
    local top_rewards = {}
    for i = 1, #trial_rewards do top_rewards[i] = 0 end
    for _, ind in ipairs(indices) do
        local sind = (ind - 1) * 2 + 1
        top_rewards[sind + 0] = trial_rewards[sind + 0]
        top_rewards[sind + 1] = trial_rewards[sind + 1]
    end
    print("top:", top_rewards)

    local _, reward_dev = calc_mean_dev(top_rewards)
    --print("mean, dev:", _, reward_dev)
    if reward_dev == 0 then reward_dev = 1 end

    for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end

    -- NOTE: step no longer directly incorporates learning_rate.
    for i = 1, cfg.epoch_trials do
        local ind = (i - 1) * 2 + 1
        local pos = top_rewards[ind + 0]
        local neg = top_rewards[ind + 1]
        local pair_reward = pos - neg
        local noise = trial_noise[i]
        for j, v in ipairs(noise) do
            step[j] = step[j] + pair_reward * v / cfg.epoch_top_trials
        end
    end

    local step_mean, step_dev = calc_mean_dev(step)
    if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
    --print("step stddev:", step_dev)
    print("full step stddev:", cfg.learning_rate * step_dev)

    for i, v in ipairs(base_params) do
        base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
    end

    if cfg.enable_network then
        network:distribute(base_params)
        network:save()
    else
        print("note: not updating weights in playable mode.")
    end

    print()
end

-- presses the given button on odd frames and releases it on even ones.
local function joypad_mash(button)
    local jp_mash = {
        up = false,
        down = false,
        left = false,
        right = false,
        A = false,
        B = false,
        select = false,
        start = false,
    }
    assert(jp_mash[button] == false, "invalid button: "..tostring(button))
    jp_mash[button] = emu.framecount() % 2 == 1
    joypad.write(1, jp_mash)
end

-- ends the current trial: reports and stores its reward, runs the learning
-- step when the epoch is complete, restores the starting savestate, and
-- loads the weights for the next trial.
local function do_reset()
    local state = get_state()
    -- be a little more descriptive.
    if state == 'dead' and get_timer() == 0 then state = 'timeup' end

    if trial_i >= 0 and cfg.defer_prints then
        if trial_i == 0 then
            print('test trial reward:', reward, "("..state..")")
        elseif cfg.negate_trials then
            --local dir = trial_neg and "negative" or "positive"
            --print('trial', trial_i, dir, 'reward:', reward, "("..state..")")

            -- report both halves together once the second half has finished.
            if trial_neg then
                local pos = trial_rewards[#trial_rewards]
                local neg = reward

                local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
                print(fmt:format(trial_i, pos, neg, last_trial_state, state))
            end
            last_trial_state = state
        else
            print('trial', trial_i, 'reward:', reward, "("..state..")")
        end
    else
        print("reward:", reward, "("..state..")")
    end

    if trial_i >= 0 then
        if trial_i == 0 or not cfg.negate_trials then
            trial_rewards[trial_i] = reward
        else
            trial_rewards[#trial_rewards + 1] = reward
        end
    end

    if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
        if epoch_i > 0 then learn_from_epoch() end
        epoch_i = epoch_i + 1
        prepare_epoch()
    end

    if get_state() == 'loading' then advance() end -- kind of a hack.
    reward = 0
    powerup_old = R(0x754)
    status_old = R(0x756)
    coins_old = R(0x7ED) * 10 + R(0x7EE)
    score_old = get_score()

    -- set number of lives. (mario gets n+1 chances)
    W(0x75A, cfg.starting_lives)

    if cfg.start_big then -- make mario "super".
        W(0x754, 0)
        W(0x756, 1)
    end

    --max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
    --max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time)
    max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
    max_time = ceil(max_time)
    --max_time = cfg.cap_time

    -- the first call captures the starting point; later calls return to it.
    if once then
        savestate.load(startsave)
    else
        savestate.save(startsave)
    end

    once = true
    jp = nil
    screen_scroll_delta = 0
    emu.frameadvance() -- prevents emulator from quirking up.

    load_next_trial()

    reset = false
end

-- builds the network, boots the game, mashes start through the opening
-- frames, and then tries to load previously saved weights.
local function init()
    network = make_network(gcfg.input_size)
    network:reset()
    network:print()
    print("parameters:", network.n_param)

    if cfg.init_zeros then
        local params = network:collect()
        for i, w in ipairs(params) do params[i] = 0 end
        network:distribute(params)
    end

    emu.poweron()
    emu.unpause()
    emu.speedmode("turbo")

    while emu.framecount() < 195 do -- FIXME: don't hardcode this.
        joypad_mash('start')
        emu.frameadvance()
    end
    print(emu.framecount())

    -- a failed load is reported but is not fatal.
    local res, err = pcall(network.load, network)
    if res == false then print(err) end
end

-- per-frame work: timer handling, reward bookkeeping, and, unless dummy is
-- true, gathering inputs and asking the network for the next joypad state.
local function doit(dummy)
    local ingame_paused = get_state() == "paused"

    -- every few frames mario stands still, forcibly decrease the timer.
    -- this includes having the game paused.
    -- TODO: more robust. doesn't detect moonwalking against a wall.
    -- well, that shouldn't happen anymore now that i've disabled left+right.
    local timer = get_timer()
    if ingame_paused
    or (random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0) then
        timer = timer - 1
    end
    timer = clamp(timer, 0, max_time)
    if cfg.enable_network then
        set_timer(timer)
    end

    -- split the frame counter into groups of three digits for display.
    -- tf1 counts thousands, so it is scaled back up before being subtracted;
    -- this keeps tf2 an exact integer.
    local tf0 = total_frames % 1000
    local tf1 = (total_frames % 1000000 - tf0) / 1000
    local tf2 = (total_frames - tf0 - tf1 * 1000) / 1000000
    gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')

    screen_scroll_delta = screen_scroll_delta + R(0x775)

    if dummy == true then
        -- don't invoke AI this frame. (keep holding the old inputs)
        gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
        return
    end

    empty(sprite_input)
    empty(tile_input)
    empty(extra_input)

    -- TODO: check if mario is in a playable state.
    local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = R(0x754)
    local status = R(0x756)
    mark_sprite(x + 8, y + 24, -powerup - 1)

    local vx, vy = S(0x57), S(0x9F)
    insert(extra_input, vx)
    insert(extra_input, vy)

    if cfg.time_inputs then
        for i=2,5 do
            insert(extra_input, band(total_frames, lshift(1, i)))
        end
    end

    handle_enemies()
    handle_fireballs()
    -- blocks being hit. not interactable; we don't care!
    --handle_blocks()
    handle_hammers()
    handle_misc()
    handle_tiles()

    local coins = R(0x7ED) * 10 + R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)
    local flagpole_bonus = R(0xE) == 4 and cfg.frameskip or 0
    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
    local score_delta = get_score() - score_old
    if score_delta < 0 then score_delta = 0 end
    local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
    screen_scroll_delta = 0

    if cfg.decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end

    if not ingame_paused then reward = reward + reward_delta end

    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')

    if get_state() == 'dead' and state_old ~= 'dead' then
        --print("dead. lives remaining:", R(0x75A, 0))
        if R(0x75A, 0) == 0 then reset = true end
    end
    if get_state() == 'lose' then
        -- this shouldn't happen if we catch the deaths as above.
        print("ran out of lives.")
        reset = true
    end

    -- lose a point for every frame paused.
    --if ingame_paused then reward = reward - 1 end
    if ingame_paused then reward = reward - 402; reset = true end

    -- if we've run out of time while the game is paused...
    -- that's cheating! unpause.
    force_start = ingame_paused and timer == 0

    local X = {}
    for i, v in ipairs(sprite_input) do insert(X, v / 256) end
    for i, v in ipairs(extra_input) do insert(X, v / 256) end
    nn.reshape(X, 1, gcfg.input_size)
    nn.reshape(tile_input, 1, gcfg.tile_count)

    if (cfg.enable_network and get_state() == 'playing') or ingame_paused then
        total_frames = total_frames + cfg.frameskip

        local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})

        local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
        if cfg.det_epsilon and random() < eps then
            -- exploration: pick a joypad state uniformly at random.
            local i = floor(random() * #gcfg.jp_lut) + 1
            jp = nn.copy(gcfg.jp_lut[i], jp)
        else
            local choose = cfg.deterministic and argmax or softchoice
            local ind = choose(unpack(outputs[nn_z]))
            jp = nn.copy(gcfg.jp_lut[ind], jp)
        end

        if force_start then
            -- start follows last frame's force_start value, so the button
            -- is released on the first frame and pressed from the second.
            jp = {
                up = false,
                down = false,
                left = false,
                right = false,
                A = false,
                B = false,
                start = force_start_old,
                select = false,
            }
        end
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    force_start_old = force_start
    state_old = get_state()
    score_old = get_score()
end

init()

while true do
    gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')

    while gcfg.bad_states[get_state()] do
        -- mash the start button until we have control.
        joypad_mash('start')
        reset = true

        advance()
        gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')

        while get_state() == "loading" do advance() end -- kind of a hack.
        state_old = get_state()
    end

    if reset then do_reset() end

    if not cfg.enable_network then
        -- infinite time cheat. super handy for testing.
        if R(0xE) == 8 then
            set_timer(667)
            poketime = true
        elseif poketime then
            poketime = false
            set_timer(1)
        end

        -- infinite lives.
        W(0x75A, 1)
    end

    -- FIXME: if the game lags then we might miss our frame to change inputs!
    --        don't rely on emu.framecount.
    local doot = jp == nil or emu.framecount() % cfg.frameskip == 0
    doit(not doot)

    -- jp might still be nil if we're not ingame or we're not playing.
    if jp ~= nil then joypad.write(1, jp) end

    advance()
end