From a836314b8b30d55349fbfc8cfce56cd8e11f5ed8 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sat, 12 May 2018 22:38:51 +0200 Subject: [PATCH] refactor game and utility functions --- gameconfig.lua | 35 ---- main.lua | 493 ++++++++----------------------------------------- nn.lua | 36 +--- smb.lua | 290 +++++++++++++++++++++++++++++ util.lua | 176 ++++++++++++++++++ 5 files changed, 549 insertions(+), 481 deletions(-) create mode 100644 smb.lua create mode 100644 util.lua diff --git a/gameconfig.lua b/gameconfig.lua index 34e1b5a..775cae7 100644 --- a/gameconfig.lua +++ b/gameconfig.lua @@ -66,41 +66,6 @@ local gcfg = { select = false, start = false, B = false, A = false, }, }, - - rotation_offsets = { -- FIXME: not all of these are pixel-perfect. - 0, -40, -- 0x00 - 6, -38, - 15, -37, - 22, -32, - 28, -28, - 32, -22, - 37, -14, - 39, -6, - 40, 0, -- 0x08 - 38, 7, - 37, 15, - 33, 23, - 27, 29, - 22, 33, - 14, 37, - 6, 39, - 0, 41, -- 0x10 - -7, 40, - -16, 38, - -22, 34, - -28, 28, - -34, 23, - -38, 16, - -40, 8, - -40, -0, -- 0x18 - -40, -6, - -38, -14, - -34, -22, - -28, -28, - -22, -32, - -16, -36, - -8, -38, - }, } return setmetatable(gcfg, { diff --git a/main.lua b/main.lua index c64551c..1f2e0b7 100644 --- a/main.lua +++ b/main.lua @@ -31,10 +31,6 @@ local startsave = savestate.create(1) local poketime = false local max_time -local sprite_input = {} -local tile_input = {} -local extra_input = {} - local jp local screen_scroll_delta @@ -74,9 +70,6 @@ local insert = table.insert local remove = table.remove local unpack = table.unpack or unpack local sort = table.sort -local R = memory.readbyteunsigned -local S = memory.readbyte --signed -local W = memory.writebyte local band = bit.band local bor = bit.bor @@ -90,6 +83,16 @@ local ror = bit.ror local gui = gui +local util = require("util") +local argmax = util.argmax +local calc_mean_dev = util.calc_mean_dev +local clamp = util.clamp +local copy = util.copy +local empty = util.empty +local lerp = util.lerp 
+local softchoice = util.softchoice +local unperturbed_rank = util.unperturbed_rank + -- utilities. local log_map = { @@ -126,97 +129,6 @@ local function log_csv(t) f:close() end -local function boolean_xor(a, b) - if a and b then return false end - if not a and not b then return false end - return true -end - -local function signbyte(x) - if x >= 128 then x = 256 - x end - return x -end - -local _invlog2 = 1 / log(2) -local function log2(x) return log(x) * _invlog2 end - -local function clamp(x, l, u) return min(max(x, l), u) end - -local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end - -local function argmax(...) - local max_i = 0 - local max_v = -999999999 - for i=1, select("#", ...) do - local v = select(i, ...) - if v > max_v then - max_i = i - max_v = v - end - end - return max_i -end - -local function softchoice(...) - local t = random() - local psum = 0 - for i=1, select("#", ...) do - local p = select(i, ...) - psum = psum + p - if t < psum then - return i - end - end -end - -local function argmax2(t) - return t[1] > t[2] -end - -local function rchoice2(t) - return t[1] > random() -end - -local function rbool() - return 0.5 >= random() -end - -local function empty(t) - for k, _ in pairs(t) do t[k] = nil end - return t -end - -local function calc_mean_dev(x) - local mean = 0 - for i, v in ipairs(x) do - mean = mean + v / #x - end - - local dev = 0 - for i, v in ipairs(x) do - local delta = v - mean - dev = dev + delta * delta / #x - end - - return mean, sqrt(dev) -end - -local function normalize(x, out) - out = out or x - local mean, dev = calc_mean_dev(x) - if dev <= 0 then dev = 1 end - for i, v in ipairs(x) do out[i] = (v - mean) / dev end - return out -end - -local function normalize_wrt(x, s, out) - out = out or x - local mean, dev = calc_mean_dev(s) - if dev <= 0 then dev = 1 end - for i, v in ipairs(x) do out[i] = (v - mean) / dev end - return out -end - -- network parameters. 
package.loaded['nn'] = nil -- DEBUG @@ -247,219 +159,10 @@ local function make_network(input_size) end -- and here we go with the game stuff. +-- which was all refactored out, so this comment looks a little silly now. --- disassembly used for reference: --- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM - -local function get_timer() - return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA) -end - -local function get_score() - return R(0x7DE) * 10000 + - R(0x7DF) * 1000 + - R(0x7E0) * 100 + - R(0x7E1) * 10 + - R(0x7E2) -end - -local function set_timer(time) - W(0x7F8, floor(time / 100)) - W(0x7F9, floor((time / 10) % 10)) - W(0x7FA, floor(time % 10)) -end - -local function mark_sprite(x, y, t) - if x < 0 or x >= 256 or y < 0 or y > 224 then - sprite_input[#sprite_input+1] = 0 - sprite_input[#sprite_input+1] = 0 - sprite_input[#sprite_input+1] = 0 - else - sprite_input[#sprite_input+1] = x - sprite_input[#sprite_input+1] = y - sprite_input[#sprite_input+1] = t - end - if t == 0 then return end - if cfg.enable_overlay then - gui.box(x-4, y-4, x+4, y+4) - --gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000') - gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F') - --gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F') - end -end - -local function mark_tile(x, y, t) - tile_input[#tile_input+1] = t - if t == 0 then return end - if cfg.enable_overlay then - gui.box(x-8, y-8, x+8, y+8) - gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000') - end -end - -local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr) - local spl_l = R(0x71A) - local spl_r = R(0x71B) - local sx_l = R(0x71C) - local sx_r = R(0x71D) - - local x = R(x_addr + i) - local y = R(y_addr + i) - local sx, sy = x, y - if pageloc_addr ~= nil then - local page = R(pageloc_addr + i) - sx = sx - sx_l - (spl_l - page) * 256 - else - sx = sx - sx_l - end - if hipos_addr ~= nil then - local hipos = S(hipos_addr + i) 
- sy = sy + (signbyte(hipos) - 1) * 256 - end - - return sx, sy -end - -local function paused() return band(R(0x776), 1) end - -local function get_state() - if R(0xE) == 0xFF then return 'power' end - if R(0x774) > 0 then return 'lagging' end - if R(0x7A2) > 0 then return 'waiting_demo' end - if R(0x717) > 0 then return 'playing_demo' end --- if R(0x770) == 0xFF then return 'power' end - if paused() ~= 0 then return 'paused' end - if R(0xE) == 0 then return 'world_screen' end --- if R(0x712) == 1 then return 'deadmusic' end - if R(0x7CA) == 0x94 then return 'dead' end - if R(0xE) == 4 then return 'win_flagpole' end - if R(0xE) == 5 then return 'win_walking' end - if R(0xE) == 6 then return 'lose' end --- if R(0x770) == 0 then return 'not_playing' end - if R(0x770) == 2 then return 'win_castle' end - if R(0x772) == 2 then return 'no_control' end - if R(0x772) == 3 then return 'playing' end - if R(0x770) == 1 then return 'loading' end - if R(0x770) == 3 then return 'lose' end - return 'unknown' -end - -local function advance() - emu.frameadvance() - while emu.lagged() do emu.frameadvance() end -- skip lag frames. - while R(0x774) > 0 do emu.frameadvance() end -- also lag frames. 
-end - -local function handle_enemies() - -- enemies, flagpole - for i = 0, 5 do - local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6) - x, y = x + 8, y + 16 - local tid = R(0x16 + i) - local flags = R(0xF + i) - --local offscr = R(0x3D8 + i) - local invisible = tid < 0x10 and flags == 0 - if tid == 0x30 then y = y - 8 end -- flagpole flag - if tid == 0x31 then y = y - 8 end -- castle flag - if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks - if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms - if tid == 0x2D then x, y = x, y end -- bowser (TODO: determine head or body) - if tid == 0x15 then x, y = x, y - 12 end -- bowser fire - if tid == 0x32 then x, y = x, y - 8 end -- spring - -- tid == 0x35 -- toad - if tid == 0x1D or tid == 0x1B then -- rotating fire bars - x, y = x - 4, y - 12 - -- this is a mess... gotta find out its rotation and then project. - -- TODO: handle long fire bars too - local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i) - gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F') - local x_off, y_off = gcfg.rotation_offsets[rot*2+1], gcfg.rotation_offsets[rot*2+2] - x, y = x + x_off, y + y_off - end - if invisible then - mark_sprite(0, 0, 0) - else - mark_sprite(x, y, tid + 1) - end - end -end - -local function handle_fireballs() - for i = 0, 1 do - local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC) - x, y = x + 4, y + 4 - local state = R(0x24 + i) - local invisible = state == 0 - if invisible then - mark_sprite(0, 0, 0) - else - mark_sprite(x, y, 257) - end - end -end - -local function handle_blocks() - for i = 0, 3 do - local x, y = getxy(i, 0x8F, 0xD7, 0x76, 0xBE) - x, y = x + 8, y + 8 - local state = R(0x26 + i) - local invisible = state == 0 - if invisible then - mark_sprite(0, 0, 0) - else - mark_sprite(x, y, 258) - end - end -end - -local function handle_hammers() - -- hammers, coins, score bonus text... 
- for i = 0, 8 do - local x, y = getxy(i, 0x93, 0xDB, 0x7A, 0xC2) - x, y = x + 8, y + 8 - local state = R(0x2A + i) - -- skip coin effect states. not interactable; we don't care! - if state ~= 0 and state >= 0x30 then - mark_sprite(x, y, state + 1) - else - mark_sprite(0, 0, 0) - end - end -end - -local function handle_misc() - for i = 0, 0 do - local x, y = getxy(i, 0x9C, 0xE4, 0x83, 0xCB) - x, y = x + 8, y + 8 - local state = R(0x33 + i) - if state ~= 0 then - mark_sprite(x, y, state + 1) - else - mark_sprite(0, 0, 0) - end - end -end - -local function handle_tiles() - --local tile_col = R(0x6A0) - local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16 - local tile_scroll_remainder = R(0x73F) % 16 - extra_input[#extra_input+1] = tile_scroll_remainder - for y = 0, 12 do - for x = 0, 16 do - local col = (x + tile_scroll) % 32 - local t - if col < 16 then - t = R(0x500 + y * 16 + (col % 16)) - else - t = R(0x5D0 + y * 16 + (col % 16)) - end - local sx = x * 16 + 8 - tile_scroll_remainder - local sy = y * 16 + 40 - mark_tile(sx, sy, t) - end - end -end +local game = require("smb") +game.overlay = cfg.enable_overlay -- learning and evaluation. 
@@ -514,7 +217,7 @@ local function load_next_pair() trial_neg = true end - local W = nn.copy(base_params) + local W = copy(base_params) if trial_i > 0 then if trial_neg then @@ -544,7 +247,7 @@ end local function load_next_trial() if cfg.negate_trials then return load_next_pair() end trial_i = trial_i + 1 - local W = nn.copy(base_params) + local W = copy(base_params) if trial_i == 0 and not cfg.unperturbed_trial then trial_i = 1 end @@ -560,41 +263,6 @@ local function load_next_trial() network:distribute(W) end -local function fitness_shaping(rewards) - -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py - local decreasing = nn.copy(rewards) - sort(decreasing, function(a, b) return a > b end) - local shaped_returns = {} - local lamb = #rewards - - local denom = 0 - for i, v in ipairs(rewards) do - local l = log2(lamb / 2 + 1) - local r = log2(nn.indexof(decreasing, v)) - denom = denom + max(0, l - r) - end - - for i, v in ipairs(rewards) do - local l = log2(lamb / 2 + 1) - local r = log2(nn.indexof(decreasing, v)) - local numer = max(0, l - r) - insert(shaped_returns, numer / denom + 1 / lamb) - end - - return shaped_returns -end - -local function unperturbed_rank(rewards, unperturbed_reward) - -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py - local nth_place = 1 - for i, v in ipairs(rewards) do - if v > unperturbed_reward then - nth_place = nth_place + 1 - end - end - return nth_place -end - local function learn_from_epoch() print() --print('rewards:', trial_rewards) @@ -632,7 +300,7 @@ local function learn_from_epoch() best_rewards[i] = max(pos, neg) end else - best_rewards = nn.copy(trial_rewards) + best_rewards = copy(trial_rewards) end local indices = {} @@ -750,15 +418,6 @@ local function learn_from_epoch() test_trial = current_cost or 0, } - -- trying a heuristic... - --[[ - if delta_std < trial_std then - cfg.deviation = cfg.deviation * 0.933 - -- this one might be bad... 
- cfg.weight_decay = cfg.weight_decay * 0.933 - end - --]] - if cfg.enable_network then network:distribute(base_params) network:save(cfg.params_fn) @@ -786,9 +445,9 @@ local function joypad_mash(button) end local function do_reset() - local state = get_state() + local state = game.get_state() -- be a little more descriptive. - if state == 'dead' and get_timer() == 0 then state = 'timeup' end + if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end if trial_i >= 0 and cfg.defer_prints then if trial_i == 0 then @@ -825,20 +484,20 @@ local function do_reset() prepare_epoch() end - if get_state() == 'loading' then advance() end -- kind of a hack. + if game.get_state() == 'loading' then game.advance() end -- kind of a hack. reward = 0 - powerup_old = R(0x754) - status_old = R(0x756) - coins_old = R(0x7ED) * 10 + R(0x7EE) - score_old = get_score() + powerup_old = game.R(0x754) + status_old = game.R(0x756) + coins_old = game.R(0x7ED) * 10 + game.R(0x7EE) + score_old = game.get_score() -- set number of lives. (mario gets n+1 chances) - W(0x75A, cfg.starting_lives) + game.W(0x75A, cfg.starting_lives) if cfg.start_big then -- make mario "super". - W(0x754, 0) - W(0x756, 1) + game.W(0x754, 0) + game.W(0x756, 1) end --max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time) @@ -896,20 +555,20 @@ local function prepare_reset() end local function doit(dummy) - local ingame_paused = get_state() == "paused" + local ingame_paused = game.get_state() == "paused" -- every few frames mario stands still, forcibly decrease the timer. -- this includes having the game paused. -- TODO: more robust. doesn't detect moonwalking against a wall. -- well, that shouldn't happen anymore now that i've disabled left+right. 
- local timer = get_timer() - if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then + local timer = game.get_timer() + if ingame_paused or random() > 1 - cfg.timer_loser and game.R(0x1D) == 0 and game.R(0x57) == 0 then timer = timer - 1 end if not cfg.playback_mode then timer = clamp(timer, 0, max_time) if cfg.enable_network then - set_timer(timer) + game.set_timer(timer) end end @@ -918,7 +577,7 @@ local function doit(dummy) local tf2 = (total_frames - tf0 - tf1) / 1000000 gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F') - screen_scroll_delta = screen_scroll_delta + R(0x775) + screen_scroll_delta = screen_scroll_delta + game.R(0x775) if dummy == true then -- don't invoke AI this frame. (keep holding the old inputs) @@ -926,37 +585,35 @@ local function doit(dummy) return end - empty(sprite_input) - empty(tile_input) - empty(extra_input) + empty(game.sprite_input) + empty(game.tile_input) + empty(game.extra_input) -- TODO: check if mario is in a playable state. - local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5) - local powerup = R(0x754) - local status = R(0x756) - mark_sprite(x + 8, y + 24, -powerup - 1) + local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5) + local powerup = game.R(0x754) + local status = game.R(0x756) + game.mark_sprite(x + 8, y + 24, -powerup - 1) - local vx, vy = S(0x57), S(0x9F) - -- i shouldn't need to do this if it's signed, but apparently... - insert(extra_input, signbyte(vx) * 16) - insert(extra_input, signbyte(vy) * 16) + local vx, vy = game.S(0x57), game.S(0x9F) + insert(game.extra_input, vx * 16) + insert(game.extra_input, vy * 16) if cfg.time_inputs then for i=2,5 do local v = band(trial_frames, lshift(1, i)) == 0 and -181 or 181 - insert(extra_input, v) + insert(game.extra_input, v) end end - handle_enemies() - handle_fireballs() - -- blocks being hit. not interactable; we don't care! 
- --handle_blocks() - handle_hammers() - handle_misc() - handle_tiles() + game.handle_enemies() + game.handle_fireballs() + --game.handle_blocks() -- blocks being hit. not interactable; we don't care! + game.handle_hammers() + game.handle_misc() + game.handle_tiles() - local coins = R(0x7ED) * 10 + R(0x7EE) + local coins = game.R(0x7ED) * 10 + game.R(0x7EE) local coins_delta = coins - coins_old -- handle wrap-around. if coins_delta < 0 then coins_delta = 100 + coins - coins_old end @@ -964,9 +621,9 @@ local function doit(dummy) local powerup_delta = powerup_old - powerup -- 2 is fire mario. local status_delta = clamp(status - status_old, -1, 1) - local flagpole_bonus = R(0xE) == 4 and cfg.frameskip or 0 + local flagpole_bonus = game.R(0xE) == 4 and cfg.frameskip or 0 --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus - local score_delta = get_score() - score_old + local score_delta = game.get_score() - score_old if score_delta < 0 then score_delta = 0 end local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus screen_scroll_delta = 0 @@ -978,11 +635,11 @@ local function doit(dummy) --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F') gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F') - if get_state() == 'dead' and state_old ~= 'dead' then - --print("dead. lives remaining:", R(0x75A, 0)) - if R(0x75A, 0) == 0 then prepare_reset() end + if game.get_state() == 'dead' and state_old ~= 'dead' then + --print("dead. lives remaining:", game.R(0x75A, 0)) + if game.R(0x75A, 0) == 0 then prepare_reset() end end - if get_state() == 'lose' then + if game.get_state() == 'lose' then -- this shouldn't happen if we catch the deaths as above. 
print("ran out of lives.") if not cfg.playback_mode then prepare_reset() end @@ -997,25 +654,25 @@ local function doit(dummy) force_start = ingame_paused and timer == 0 local X = {} - for i, v in ipairs(sprite_input) do insert(X, v / 256) end - for i, v in ipairs(extra_input) do insert(X, v / 256) end + for i, v in ipairs(game.sprite_input) do insert(X, v / 256) end + for i, v in ipairs(game.extra_input) do insert(X, v / 256) end nn.reshape(X, 1, gcfg.input_size) - nn.reshape(tile_input, 1, gcfg.tile_count) + nn.reshape(game.tile_input, 1, gcfg.tile_count) trial_frames = trial_frames + cfg.frameskip - if cfg.enable_network and get_state() == 'playing' or ingame_paused then + if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then total_frames = total_frames + cfg.frameskip - local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) + local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input}) local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames) if cfg.det_epsilon and random() < eps then local i = floor(random() * #gcfg.jp_lut) + 1 - jp = nn.copy(gcfg.jp_lut[i], jp) + jp = copy(gcfg.jp_lut[i], jp) else local choose = cfg.deterministic and argmax or softchoice local ind = choose(unpack(outputs[nn_z])) - jp = nn.copy(gcfg.jp_lut[ind], jp) + jp = copy(gcfg.jp_lut[ind], jp) end if force_start then @@ -1036,41 +693,41 @@ local function doit(dummy) powerup_old = powerup status_old = status force_start_old = force_start - state_old = get_state() - score_old = get_score() + state_old = game.get_state() + score_old = game.get_score() end init() while true do - gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') + gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F') - while gcfg.bad_states[get_state()] do + while gcfg.bad_states[game.get_state()] do -- mash the start button until we have control. 
joypad_mash('start') prepare_reset() - advance() - gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') + game.advance() + gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F') - while get_state() == "loading" do advance() end -- kind of a hack. - state_old = get_state() + while game.get_state() == "loading" do game.advance() end -- kind of a hack. + state_old = game.get_state() end if reset then do_reset() end if not cfg.enable_network then -- infinite time cheat. super handy for testing. - if R(0xE) == 8 then - set_timer(667) + if game.R(0xE) == 8 then + game.set_timer(667) poketime = true elseif poketime then poketime = false - set_timer(1) + game.set_timer(1) end -- infinite lives. - W(0x75A, 1) + game.W(0x75A, 1) end -- FIXME: if the game lags then we might miss our frame to change inputs! @@ -1081,5 +738,5 @@ while true do -- jp might still be nil if we're not ingame or we're not playing. if jp ~= nil then joypad.write(1, jp) end - advance() + game.advance() end diff --git a/nn.lua b/nn.lua index b2d0b55..dea421c 100644 --- a/nn.lua +++ b/nn.lua @@ -21,28 +21,11 @@ local unpack = table.unpack or unpack local Base = require("Base") +local util = require("util") + -- hacks local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) end --- general utilities - -local function copy(t, out) -- shallow copy - assert(type(t) == "table") - local out = out or {} - for k, v in pairs(t) do out[k] = v end - return out -end - -local function indexof(t, a) - assert(type(t) == "table") - for k, v in pairs(t) do if v == a then return k end end - return nil -end - -local function contains(t, a) - return indexof(t, a) ~= nil -end - -- math utilities local function prod(x, ...) 
@@ -198,7 +181,7 @@ end local function cache(bs, shape) if bs == nil then return nil end - local fullshape = copy(shape) + local fullshape = util.copy(shape) insert(fullshape, bs, 1) return zeros(fullshape) end @@ -274,19 +257,19 @@ local function traverse(node_in, node_out, nodes, dummy_mode) if seen_up[node] then local all_parents_added = true for _, parent in ipairs(node.parents) do - if not contains(nodes, parent) then + if not util.contains(nodes, parent) then all_parents_added = false break end end - if not contains(nodes, node) and all_parents_added then + if not util.contains(nodes, node) and all_parents_added then insert(nodes, node) end for _, child in ipairs(node.children) do insert(q, child) end end end - if dummy_mode then remove(nodes, indexof(nodes, node_in)) end + if dummy_mode then remove(nodes, util.indexof(nodes, node_in)) end return nodes end @@ -705,7 +688,7 @@ function Model:forward(inputs) local outputs = {} for i, node in ipairs(self.nodes) do --print(i, node.name) - if contains(self.nodes_in, node) then + if util.contains(self.nodes_in, node) then local X = inputs[node] assert(X ~= nil, ("missing input for node %s"):format(node.name)) assert(X.shape, ("missing shape for node %s"):format(node.name)) @@ -713,7 +696,7 @@ function Model:forward(inputs) else values[node] = node:propagate(values) end - if contains(self.nodes_out, node) then + if util.contains(self.nodes_out, node) then outputs[node] = values[node] end end @@ -805,9 +788,6 @@ function Model:load(fn) end return { - copy = copy, - indexof = indexof, - contains = contains, prod = prod, uniform = uniform, normal = normal, diff --git a/smb.lua b/smb.lua new file mode 100644 index 0000000..43f907d --- /dev/null +++ b/smb.lua @@ -0,0 +1,290 @@ +-- disassembly used for reference: +-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM + +local rotation_offsets = { -- FIXME: not all of these are pixel-perfect. 
+ 0, -40, -- 0x00 + 6, -38, + 15, -37, + 22, -32, + 28, -28, + 32, -22, + 37, -14, + 39, -6, + 40, 0, -- 0x08 + 38, 7, + 37, 15, + 33, 23, + 27, 29, + 22, 33, + 14, 37, + 6, 39, + 0, 41, -- 0x10 + -7, 40, + -16, 38, + -22, 34, + -28, 28, + -34, 23, + -38, 16, + -40, 8, + -40, -0, -- 0x18 + -40, -6, + -38, -14, + -34, -22, + -28, -28, + -22, -32, + -16, -36, + -8, -38, +} + +local util = require("util") +local R = memory.readbyteunsigned +local W = memory.writebyte +local function S(addr) return util.signbyte(R(addr)) end + +-- TODO: reinterface to one "input" array visible to main.lua. +local sprite_input = {} +local tile_input = {} +local extra_input = {} + +local overlay = false + +local band = bit.band +local floor = math.floor + +local function get_timer() + return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA) +end + +local function get_score() + return R(0x7DE) * 10000 + + R(0x7DF) * 1000 + + R(0x7E0) * 100 + + R(0x7E1) * 10 + + R(0x7E2) +end + +local function set_timer(time) + W(0x7F8, floor(time / 100)) + W(0x7F9, floor((time / 10) % 10)) + W(0x7FA, floor(time % 10)) +end + +local function mark_sprite(x, y, t) + if x < 0 or x >= 256 or y < 0 or y > 224 then + sprite_input[#sprite_input+1] = 0 + sprite_input[#sprite_input+1] = 0 + sprite_input[#sprite_input+1] = 0 + else + sprite_input[#sprite_input+1] = x + sprite_input[#sprite_input+1] = y + sprite_input[#sprite_input+1] = t + end + if t == 0 then return end + if overlay then + gui.box(x-4, y-4, x+4, y+4) + --gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000') + gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F') + --gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F') + end +end + +local function mark_tile(x, y, t) + tile_input[#tile_input+1] = t + if t == 0 then return end + if overlay then + gui.box(x-8, y-8, x+8, y+8) + gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000') + end +end + +local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr) + local spl_l = 
R(0x71A) + local spl_r = R(0x71B) + local sx_l = R(0x71C) + local sx_r = R(0x71D) + + local x = R(x_addr + i) + local y = R(y_addr + i) + local sx, sy = x, y + if pageloc_addr ~= nil then + local page = R(pageloc_addr + i) + sx = sx - sx_l - (spl_l - page) * 256 + else + sx = sx - sx_l + end + if hipos_addr ~= nil then + local hipos = S(hipos_addr + i) + sy = sy + (hipos - 1) * 256 + end + + return sx, sy +end + +local function paused() return band(R(0x776), 1) end + +local function get_state() + if R(0xE) == 0xFF then return 'power' end + if R(0x774) > 0 then return 'lagging' end + if R(0x7A2) > 0 then return 'waiting_demo' end + if R(0x717) > 0 then return 'playing_demo' end +-- if R(0x770) == 0xFF then return 'power' end + if paused() ~= 0 then return 'paused' end + if R(0xE) == 0 then return 'world_screen' end +-- if R(0x712) == 1 then return 'deadmusic' end + if R(0x7CA) == 0x94 then return 'dead' end + if R(0xE) == 4 then return 'win_flagpole' end + if R(0xE) == 5 then return 'win_walking' end + if R(0xE) == 6 then return 'lose' end +-- if R(0x770) == 0 then return 'not_playing' end + if R(0x770) == 2 then return 'win_castle' end + if R(0x772) == 2 then return 'no_control' end + if R(0x772) == 3 then return 'playing' end + if R(0x770) == 1 then return 'loading' end + if R(0x770) == 3 then return 'lose' end + return 'unknown' +end + +local function advance() + emu.frameadvance() + while emu.lagged() do emu.frameadvance() end -- skip lag frames. + while R(0x774) > 0 do emu.frameadvance() end -- also lag frames. 
+end + +local function handle_enemies() + -- enemies, flagpole + for i = 0, 5 do + local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6) + x, y = x + 8, y + 16 + local tid = R(0x16 + i) + local flags = R(0xF + i) + --local offscr = R(0x3D8 + i) + local invisible = tid < 0x10 and flags == 0 + if tid == 0x30 then y = y - 8 end -- flagpole flag + if tid == 0x31 then y = y - 8 end -- castle flag + if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks + if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms + if tid == 0x2D then x, y = x, y end -- bowser (TODO: determine head or body) + if tid == 0x15 then x, y = x, y - 12 end -- bowser fire + if tid == 0x32 then x, y = x, y - 8 end -- spring + -- tid == 0x35 -- toad + if tid == 0x1D or tid == 0x1B then -- rotating fire bars + x, y = x - 4, y - 12 + -- this is a mess... gotta find out its rotation and then project. + -- TODO: handle long fire bars too + local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i) + gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F') + local x_off, y_off = rotation_offsets[rot*2+1], rotation_offsets[rot*2+2] + x, y = x + x_off, y + y_off + end + if invisible then + mark_sprite(0, 0, 0) + else + mark_sprite(x, y, tid + 1) + end + end +end + +local function handle_fireballs() + for i = 0, 1 do + local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC) + x, y = x + 4, y + 4 + local state = R(0x24 + i) + local invisible = state == 0 + if invisible then + mark_sprite(0, 0, 0) + else + mark_sprite(x, y, 257) + end + end +end + +local function handle_blocks() + for i = 0, 3 do + local x, y = getxy(i, 0x8F, 0xD7, 0x76, 0xBE) + x, y = x + 8, y + 8 + local state = R(0x26 + i) + local invisible = state == 0 + if invisible then + mark_sprite(0, 0, 0) + else + mark_sprite(x, y, 258) + end + end +end + +local function handle_hammers() + -- hammers, coins, score bonus text... 
+ for i = 0, 8 do + local x, y = getxy(i, 0x93, 0xDB, 0x7A, 0xC2) + x, y = x + 8, y + 8 + local state = R(0x2A + i) + -- skip coin effect states. not interactable; we don't care! + if state ~= 0 and state >= 0x30 then + mark_sprite(x, y, state + 1) + else + mark_sprite(0, 0, 0) + end + end +end + +local function handle_misc() + for i = 0, 0 do + local x, y = getxy(i, 0x9C, 0xE4, 0x83, 0xCB) + x, y = x + 8, y + 8 + local state = R(0x33 + i) + if state ~= 0 then + mark_sprite(x, y, state + 1) + else + mark_sprite(0, 0, 0) + end + end +end + +local function handle_tiles() + --local tile_col = R(0x6A0) + local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16 + local tile_scroll_remainder = R(0x73F) % 16 + extra_input[#extra_input+1] = tile_scroll_remainder + for y = 0, 12 do + for x = 0, 16 do + local col = (x + tile_scroll) % 32 + local t + if col < 16 then + t = R(0x500 + y * 16 + (col % 16)) + else + t = R(0x5D0 + y * 16 + (col % 16)) + end + local sx = x * 16 + 8 - tile_scroll_remainder + local sy = y * 16 + 40 + mark_tile(sx, sy, t) + end + end +end + +return { +-- TODO: don't expose these; provide interfaces for everything needed. +R=R, +W=W, +S=S, +overlay=overlay, + +sprite_input=sprite_input, +tile_input=tile_input, +extra_input=extra_input, + +get_timer=get_timer, +get_score=get_score, +set_timer=set_timer, +mark_sprite=mark_sprite, +mark_tile=mark_tile, +getxy=getxy, +paused=paused, +get_state=get_state, +advance=advance, +handle_enemies=handle_enemies, +handle_fireballs=handle_fireballs, +handle_blocks=handle_blocks, +handle_hammers=handle_hammers, +handle_misc=handle_misc, +handle_tiles=handle_tiles, +} diff --git a/util.lua b/util.lua new file mode 100644 index 0000000..a5d5c7f --- /dev/null +++ b/util.lua @@ -0,0 +1,176 @@ +-- TODO: reorganize function order. 
+
+local assert = assert
+local ipairs = ipairs
+local log = math.log
+local max = math.max
+local min = math.min
+local pairs = pairs
+local random = math.random
+local select = select
+local sort = table.sort
+local sqrt = math.sqrt
+local type = type
+
+-- reinterpret an unsigned byte (0..255) as a two's-complement signed byte (-128..127).
+local function signbyte(x)
+    if x >= 128 then x = x - 256 end
+    return x
+end
+
+local function boolean_xor(a, b) -- true when exactly one of a, b is truthy.
+    if a and b then return false end
+    if not a and not b then return false end
+    return true
+end
+
+local _invlog2 = 1 / log(2)
+local function log2(x) return log(x) * _invlog2 end -- base-2 logarithm.
+
+local function clamp(x, l, u) return min(max(x, l), u) end -- restrict x to [l, u].
+
+-- linear interpolation from a (t <= 0) to b (t >= 1); t is clamped to [0, 1].
+local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
+
+local function argmax(...) -- 1-based index of the largest argument; 0 if nothing exceeds -inf.
+    local max_i = 0
+    local max_v = -math.huge
+    for i=1, select("#", ...) do
+        local v = select(i, ...)
+        if v > max_v then
+            max_i = i
+            max_v = v
+        end
+    end
+    return max_i
+end
+
+local function softchoice(...) -- sample index i with probability given by the i-th argument.
+    local t = random()
+    local psum = 0
+    local n = select("#", ...)
+    for i=1, n do
+        local p = select(i, ...)
+        psum = psum + p
+        if t < psum then
+            return i
+        end
+    end
+    -- rounding can leave psum just below 1; fall back to the last index rather than nil.
+    if n > 0 then return n end
+end
+
+local function empty(t) -- remove every key from t in place and return t.
+    for k, _ in pairs(t) do t[k] = nil end
+    return t
+end
+
+local function calc_mean_dev(x) -- mean and population standard deviation of sequence x.
+    local mean = 0
+    for i, v in ipairs(x) do
+        mean = mean + v / #x
+    end
+
+    local dev = 0
+    for i, v in ipairs(x) do
+        local delta = v - mean
+        dev = dev + delta * delta / #x
+    end
+
+    return mean, sqrt(dev)
+end
+
+local function normalize(x, out) -- shift/scale x to zero mean, unit deviation; in place unless out is given.
+    out = out or x
+    local mean, dev = calc_mean_dev(x)
+    if dev <= 0 then dev = 1 end -- constant input: avoid dividing by zero.
+    for i, v in ipairs(x) do out[i] = (v - mean) / dev end
+    return out
+end
+
+local function normalize_wrt(x, s, out) -- like normalize, but using the mean and deviation of s.
+    out = out or x
+    local mean, dev = calc_mean_dev(s)
+    if dev <= 0 then dev = 1 end -- constant input: avoid dividing by zero.
+    for i, v in ipairs(x) do out[i] = (v - mean) / dev end
+    return out
+end
+
+local function copy(t, out) -- shallow copy
+    assert(type(t) == "table")
+    out = out or {}
+    for k, v in pairs(t) do out[k] = v end
+    return out
+end
+
+local function indexof(t, a) -- a key of t whose value equals a (first found by pairs), else nil.
+    assert(type(t) == "table")
+    for k, v in pairs(t) do if v == a then return k end end
+    return nil
+end
+
+local function contains(t, a) -- true when some value in t equals a.
+    return indexof(t, a) ~= nil
+end
+
+local function fitness_shaping(rewards) -- rank-based utilities, one per reward, in input order.
+    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
+    local decreasing = copy(rewards)
+    sort(decreasing, function(a, b) return a > b end)
+    local lamb = #rewards
+    local l = log2(lamb / 2 + 1)
+    local shaped_returns, denom = {}, 0
+    for i, v in ipairs(rewards) do
+        shaped_returns[i] = max(0, l - log2(indexof(decreasing, v)))
+        denom = denom + shaped_returns[i]
+    end
+    for i, numer in ipairs(shaped_returns) do
+        shaped_returns[i] = numer / denom + 1 / lamb
+    end
+    return shaped_returns
+end
+
+local function unperturbed_rank(rewards, unperturbed_reward) -- 1 + count of rewards strictly greater.
+    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
+    local nth_place = 1
+    for i, v in ipairs(rewards) do
+        if v > unperturbed_reward then
+            nth_place = nth_place + 1
+        end
+    end
+    return nth_place
+end
+
+local function argmax2(t) -- true when t[1] is strictly greater than t[2].
+    return t[1] > t[2]
+end
+
+local function rchoice2(t) -- true when t[1] exceeds a fresh random() draw.
+    return t[1] > random()
+end
+
+local function rbool() -- random boolean from a single random() draw against 0.5.
+    return 0.5 >= random()
+end
+
+return {
+    signbyte=signbyte,
+    boolean_xor=boolean_xor,
+    log2=log2,
+    clamp=clamp,
+    lerp=lerp,
+    argmax=argmax,
+    softchoice=softchoice,
+    empty=empty,
+    calc_mean_dev=calc_mean_dev,
+    normalize=normalize,
+    normalize_wrt=normalize_wrt,
+    fitness_shaping=fitness_shaping,
+    unperturbed_rank=unperturbed_rank,
+    copy=copy,
+    indexof=indexof,
+    contains=contains,
+    argmax2=argmax2,
+    rchoice2=rchoice2,
+    rbool=rbool,
+}