From e2d29352c49cf829d218649bf5dbf3841ecb7cef Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 29 Jun 2017 04:51:02 +0000 Subject: [PATCH] basic learning --- main.lua | 359 ++++++++++++++++++++++++++++++++++++++++--------------- nn.lua | 216 +++++++++++++++++++-------------- 2 files changed, 391 insertions(+), 184 deletions(-) diff --git a/main.lua b/main.lua index a823a52..1c4080f 100644 --- a/main.lua +++ b/main.lua @@ -20,6 +20,63 @@ local function globalize(t) end end +-- configuration and globals. + +math.randomseed(10) + +local learning_rate = 1e-2 +local deviation = 2e-2 + +local epoch_i = 0 +local epoch_trials = 12 +local base_params +local trial_i = 0 +local trial_noise = {} +local trial_rewards = {} +local trials_remaining = 0 + +local force_start = false +local force_start_old = false + +local enable_overlay = false +local enable_network = true + +local startsave = savestate.create(1) + +local poketime = false + +local sprite_input = {} +local tile_input = {} + +local reward +local powerup_old +local status_old +local coins_old + +local once = false +local reset = true + +local ok_routines = { + [0x4] = true, -- sliding down flagpole + [0x5] = true, -- end of level auto-walk + [0x7] = true, -- start of level auto-walk + [0x8] = true, -- normal (in control) + [0x9] = true, -- acquiring mushroom + [0xA] = true, -- losing big mario + [0xB] = true, -- uhh + [0xC] = true, -- acquiring fireflower +} + +local bad_states = { + power = true, + waiting_demo = true, + playing_demo = true, + unknown = true, + lose = true, +} + +local state_old = '' + -- localize some stuff. local print = print @@ -51,6 +108,14 @@ local arshift = bit.arshift local rol = bit.rol local ror = bit.ror +-- utilities. + +local function boolean_xor(a, b) + if a and b then return false end + if not a and not b then return false end + return true +end + local function clamp(x, l, u) return min(max(x, l), u) end local function argmax(...) @@ -70,6 +135,35 @@ local function argmax2(t) return t[1] > t[2] end +local function empty(t) + for k, _ in pairs(t) do t[k] = nil end + return t +end + +local function calc_mean_dev(x) + local mean = 0 + for i, v in ipairs(x) do + mean = mean + v / #x + end + + local dev = 0 + for i, v in ipairs(x) do + local delta = v - mean + dev = dev + delta * delta / #x + end + + return mean, dev +end + +local function normalize(x, out) + out = out or x + local mean, dev = calc_mean_dev(x) + if dev <= 0 then dev = 1 end + local devs = sqrt(dev) + for i, v in ipairs(x) do out[i] = (v - mean) / devs end + return out +end + -- game-agnostic stuff (i.e. the network itself) package.loaded['nn'] = nil -- DEBUG @@ -100,9 +194,6 @@ end -- and here we go with the game stuff. -local enable_overlay = false -local enable_network = true - --[[ https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM --]] @@ -142,14 +233,17 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect. -8, -38, } -local startsave = savestate.create(1) +local function get_timer() + return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA) +end -local poketime = false +local function set_timer(time) + W(0x7F8, floor(time / 100)) + W(0x7F9, floor((time / 10) % 10)) + W(0x7FA, floor(time % 10)) +end -local sprite_input = {} -local tile_input = {} - -local function shitsprite(x, y, t) +local function mark_sprite(x, y, t) if x < 0 or x >= 256 or y < 0 or y > 224 then sprite_input[#sprite_input+1] = 0 sprite_input[#sprite_input+1] = 0 @@ -168,7 +262,7 @@ local function shitsprite(x, y, t) end end -local function shittile(x, y, t) +local function mark_tile(x, y, t) tile_input[#tile_input+1] = t if t == 0 then return end if enable_overlay then @@ -200,23 +294,9 @@ local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr) return sx, sy end -local reward -local powerup_old -local status_old -local coins_old - -local once = false -local reset = true - -emu.poweron() -emu.unpause() -emu.speedmode("normal") - -local function opermode() return R(0x770) end local function paused() return band(R(0x776), 1) end -local function subroutine() return R(0xE) end -local function getstate() +local function get_state() if R(0xE) == 0xFF then return 'power' end if R(0x774) > 0 then return 'lagging' end if R(0x7A2) > 0 then return 'waiting_demo' end @@ -238,35 +318,14 @@ local function getstate() return 'unknown' end -local ok_routines = { - [0x4] = true, -- sliding down flagpole - [0x5] = true, -- end of level auto-walk - [0x7] = true, -- start of level auto-walk - [0x8] = true, -- normal (in control) - [0x9] = true, -- acquiring mushroom - [0xA] = true, -- losing big mario - [0xB] = true, -- uhh - [0xC] = true, -- acquiring fireflower -} - -local fuckstates = { - power = true, - waiting_demo = true, - playing_demo = true, - unknown = true, - lose = true, - paused = true, -} - local function advance() emu.frameadvance() while emu.lagged() do emu.frameadvance() end -- skip lag frames. while R(0x774) > 0 do emu.frameadvance() end -- also lag frames. end -local state_old = '' while false do - local state = getstate() + local state = get_state() if state ~= state_old then print(emu.framecount(), state) state_old = state @@ -301,9 +360,9 @@ local function handle_enemies() x, y = x + x_off, y + y_off end if invisible then - shitsprite(0, 0, 0) + mark_sprite(0, 0, 0) else - shitsprite(x, y, tid + 1) + mark_sprite(x, y, tid + 1) end end end @@ -316,9 +375,9 @@ local function handle_fireballs() local state = R(0x24 + i) local invisible = state == 0 if invisible then - shitsprite(0, 0, 0) + mark_sprite(0, 0, 0) else - shitsprite(x, y, 257) + mark_sprite(x, y, 257) end end end @@ -330,9 +389,9 @@ local function handle_blocks() local state = R(0x26 + i) local invisible = state == 0 if invisible then - shitsprite(0, 0, 0) + mark_sprite(0, 0, 0) else - shitsprite(x, y, 258) + mark_sprite(x, y, 258) end end end @@ -347,9 +406,9 @@ local function handle_hammers() if state ~= 0 and state >= 0x30 then - shitsprite(x, y, state + 1) + mark_sprite(x, y, state + 1) else - shitsprite(0, 0, 0) + mark_sprite(0, 0, 0) end end end @@ -360,9 +419,9 @@ local function handle_misc() x, y = x + 8, y + 8 local state = R(0x33 + i) if state ~= 0 then - shitsprite(x, y, state + 1) + mark_sprite(x, y, state + 1) else - shitsprite(0, 0, 0) + mark_sprite(0, 0, 0) end end end @@ -383,25 +442,104 @@ local function handle_tiles() end local sx = x * 16 + 8 - tile_scroll_remainder local sy = y * 16 + 40 - shittile(sx, sy, t) + mark_tile(sx, sy, t) end end end -local function doreset() - print("resetting in state:", getstate()) +local function learn_from_epoch() + print() + print('rewards:', trial_rewards) + normalize(trial_rewards) + print('normalized:', trial_rewards) + + local reward_mean, reward_dev = calc_mean_dev(trial_rewards) + + local step = nn.zeros(#base_params) + for i = 1, epoch_trials do + local reward = trial_rewards[i] + local noise = trial_noise[i] + for j, v in ipairs(noise) do + step[j] = step[j] + reward * v + end + end + + local magnitude = learning_rate / deviation + print('stepping with magnitude', magnitude) + -- throw the division from the averaging in there too. + local altogether = magnitude / epoch_trials + for i, v in ipairs(step) do + step[i] = altogether * v + end + + local step_mean, step_dev = calc_mean_dev(step) + print("step mean:", step_mean) + print("step stddev:", step_dev) + + if step_dev > 1e-8 then + for i, v in ipairs(base_params) do + base_params[i] = v + step[i] + end + else + -- we didn't get anywhere. step in a random direction. + local noise = trial_noise[1] + for i, v in ipairs(base_params) do + base_params[i] = v + magnitude * noise[i] + end + end + + network:distribute(base_params) + + network:save() + + print() +end + +local function prepare_epoch() + print('preparing epoch '..tostring(epoch_i)..'. this might take a while.') + base_params = network:collect() + empty(trial_noise) + empty(trial_rewards) + for i = 1, epoch_trials do + local noise = nn.zeros(#base_params) + for j = 1, #base_params do noise[j] = nn.normal() end + trial_noise[i] = noise + end + trial_i = 0 +end + +local function load_next_trial() + trial_i = trial_i + 1 + print('loading trial', trial_i) + local W = nn.copy(base_params) + local noise = trial_noise[trial_i] + for i, v in ipairs(base_params) do + W[i] = v + deviation * noise[i] + end + network:distribute(W) +end + +local function do_reset() + print("resetting in state:", get_state()) + + if trial_i > 0 then trial_rewards[trial_i] = reward end + + if epoch_i == 0 or trial_i == epoch_trials then + if epoch_i > 0 then learn_from_epoch() else network:reset() end + epoch_i = epoch_i + 1 + prepare_epoch() + end if once then savestate.load(startsave) print("end of trial reward:", reward) - print() else savestate.save(startsave) end once = true -- bit of a hack: - if getstate() == 'loading' then advance() end + if get_state() == 'loading' then advance() end reward = 0 powerup_old = R(0x754) status_old = R(0x756) @@ -413,13 +551,29 @@ local function doreset() emu.frameadvance() -- prevents emulator from quirking up. + print() + load_next_trial() + reset = false - if network ~= nil then network:reset() end -- FIXME: hack end +local function init() + network = make_network(279, 8) + network:reset() + print("parameters:", network.n_param) + + emu.poweron() + emu.unpause() + emu.speedmode("normal") +end + +init() + while true do - while fuckstates[getstate()] do - --gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F') + gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') + + while bad_states[get_state()] do + --gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F') -- mash the start button until we have control. -- TODO: learn this too. --local jp = joypad.read(1) @@ -437,44 +591,39 @@ while true do reset = true - gui.text(4, 12, getstate(), '#FFFFFF', '#0000003F') advance() + gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') -- bit of a hack: - while getstate() == "loading" do advance() end - state_old = getstate() + while get_state() == "loading" do advance() end + state_old = get_state() end - if reset then doreset() end + if reset then do_reset() end if not enable_network then -- infinite time cheat. super handy for testing. if R(0xE) == 8 then - W(0x7F8, 9) - W(0x7F9, 9) - W(0x7FA, 10) + set_timer(667) poketime = true elseif poketime then poketime = false - W(0x7F8, 0) - W(0x7F9, 0) - W(0x7FA, 1) + set_timer(1) end -- infinite lives. W(0x75A, 1) end - -- empty input lists without creating a new table. - for k, v in pairs(sprite_input) do sprite_input[k] = nil end - for k, v in pairs(tile_input) do tile_input[k] = nil end + empty(sprite_input) + empty(tile_input) -- player - -- TODO: add check if mario is playable. + -- TODO: check if mario is playable. local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5) local powerup = R(0x754) local status = R(0x756) - shitsprite(x + 8, y + 24, -powerup - 1) + mark_sprite(x + 8, y + 24, -powerup - 1) handle_enemies() handle_fireballs() @@ -484,6 +633,8 @@ while true do handle_misc() handle_tiles() + local ingame_paused = get_state() == "paused" + local coins = R(0x7ED) * 10 + R(0x7EE) local coins_delta = coins - coins_old -- handle wrap-around. @@ -493,45 +644,48 @@ while true do -- 2 is fire mario. local status_delta = clamp(status - status_old, -1, 1) local screen_scroll_delta = R(0x775) - local flagpole_bonus = subroutine() == 4 and 1 or 0 + local flagpole_bonus = R(0xE) == 4 and 1 or 0 local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus - reward = reward + reward_delta + if not ingame_paused then reward = reward + reward_delta end --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F') --gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F') gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F') gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') - if getstate() == 'dead' and state_old ~= 'dead' then + if get_state() == 'dead' and state_old ~= 'dead' then print("dead. lives remaining:", R(0x75A, 0)) if R(0x75A, 0) == 0 then reset = true end end - if getstate() == 'lose' then + if get_state() == 'lose' then print("ran out of lives.") reset = true end + -- lose a point for every frame paused. + if ingame_paused then reward = reward - 1 end + -- every few frames mario stands still, forcibly decrease the timer. + -- this includes having the game paused. + -- TODO: more robust. doesn't detect moonwalking against a wall. + local timer = get_timer() local timer_loser = 1/5 - if math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then - local timer = R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA) + if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then timer = clamp(timer - 1, 0, 400) - W(0x7F8, floor(timer / 100)) - W(0x7F9, floor((timer / 10) % 10)) - W(0x7FA, floor(timer % 10)) + set_timer(timer) end + -- if we've run out of time while the game is paused... + -- that's cheating! unpause. + force_start = ingame_paused and timer == 0 + local X = {} -- TODO: cache. for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(tile_input) do insert(X, v / 256) end + --error(#X) - if network == nil then - network = make_network(#X, 8); network:reset() - print("parameters:", network.n_param) - end - - if enable_network and getstate() == 'playing' then + if enable_network and get_state() == 'playing' or ingame_paused then local outputs = network:forward(X) local softmaxed = { outputs[nn_z[1]], @@ -553,12 +707,25 @@ while true do start = argmax2(softmaxed[7]), select = argmax2(softmaxed[8]), } + if force_start then + jp = { + up = false, + down = false, + left = false, + right = false, + A = false, + B = false, + start = force_start_old, + select = false, + } + end joypad.write(1, jp) end coins_old = coins powerup_old = powerup status_old = status - state_old = getstate() + force_start_old = force_start + state_old = get_state() advance() end diff --git a/nn.lua b/nn.lua index 7eced6a..f171067 100644 --- a/nn.lua +++ b/nn.lua @@ -13,6 +13,7 @@ local cos = math.cos local sin = math.sin local insert = table.insert local remove = table.remove +local open = io.open local bor = bit.bor @@ -24,6 +25,17 @@ local function contains(t, a) return false end +local function prod(x, ...) + if type(x) == "table" then + return prod(unpack(x)) + end + local ret = x + for i = 1, select("#", ...) do + ret = ret * select(i, ...) + end + return ret +end + local function normal() -- box muller return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform()) end @@ -58,28 +70,53 @@ local function copy(t) -- shallow copy end local function allocate(t, out, init) - -- FIXME: this code is fucking disgusting. - out = out or {} - assert(type(out) == "table", type(out)) - if type(t) == "number" then - local size = t - if init ~= nil then - return init(zeros(size, out)) - else - return zeros(size, out) + local size = t + if init ~= nil then + return init(zeros(size, out)) + else + return zeros(size, out) + end +end + +local function levelorder(field, node_in, nodes) + -- horribly inefficient. + nodes = nodes or {} + local q = {node_in} + while #q > 0 do + local node = q[1] + remove(q, 1) + insert(nodes, node) + for _, child in ipairs(node[field]) do + q[#q+1] = child end end - local topsize = t[1] - t = copy(t) - remove(t, 1) - if #t == 1 then t = t[1] end - for i = 1, topsize do - local res = allocate(t, nil, init) - assert(res ~= nil) - insert(out, res) + return nodes +end + +local function traverse(node_in, node_out, nodes) + nodes = nodes or {} + local down = levelorder('children', node_in, {}) + local up = levelorder('parents', node_out, {}) + local seen = {} + for _, node in ipairs(up) do + seen[node] = bor(seen[node] or 0, 1) end - return out + for _, node in ipairs(down) do + seen[node] = bor(seen[node] or 0, 2) + if seen[node] == 3 then + insert(nodes, node) + end + end + return nodes +end + +local function traverse_all(nodes_in, nodes_out, nodes) + local all_in = {children={}} + local all_out = {parents={}} + for _, node in ipairs(nodes_in) do insert(all_in.children, node) end + for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end + return traverse(all_in, all_out, nodes or {}) end local Weights = Base:extend() @@ -95,24 +132,13 @@ function Weights:init(weight_init) end function Weights:allocate(fan_in, fan_out) + self.size = prod(self.shape) return allocate(self.size, self, function(t) --print('initializing weights of size', self.size, 'with fans', fan_in, fan_out) return self.weight_init(t, fan_in, fan_out) end) end ---[[ -local w = Weights(init_he_uniform) -w.size = {16, 16} -w:allocate(16, 16) -print(w) -do return end - -local w = zeros(16) -for i = 1, #w do w[i] = normal() * 1920 / 2560 end -print(w) ---]] - local counter = {} function Layer:init(name) assert(type(name) == "string") @@ -152,18 +178,7 @@ function Layer:_new_weights(init) return w end -local function prod(x, ...) - if type(x) == "table" then - return prod(unpack(x)) - end - local ret = x - for i = 1, select("#", ...) do - ret = ret * select(i, ...) - end - return ret -end - -function Layer:getsize() +function Layer:get_size() local size = 0 for i, w in ipairs(self.weights) do size = size + prod(w.size) end return size @@ -237,8 +252,8 @@ end function Dense:make_shape(parent) self.size_in = parent.size_out - self.coeffs.size = {self.dim, self.size_in} - self.biases.size = self.dim + self.coeffs.shape = {self.size_in, self.dim} + self.biases.shape = self.dim end function Dense:forward(X) @@ -246,11 +261,11 @@ function Dense:forward(X) self.cache = self.cache or zeros(self.size_out) local Y = self.cache - for i = 1, #self.coeffs do + for i = 1, self.dim do local res = 0 - local c = self.coeffs[i] + local c = (i - 1) * #X for j = 1, #X do - res = res + X[j] * c[j] + res = res + X[j] * self.coeffs[c + j] end Y[i] = res + self.biases[i] end @@ -281,46 +296,6 @@ function Softmax:forward(X) return Y end -local function levelorder(field, node_in, nodes) - -- horribly inefficient. - nodes = nodes or {} - local q = {node_in} - while #q > 0 do - local node = q[1] - remove(q, 1) - insert(nodes, node) - for _, child in ipairs(node[field]) do - q[#q+1] = child - end - end - return nodes -end - -local function traverse(node_in, node_out, nodes) - nodes = nodes or {} - local down = levelorder('children', node_in, {}) - local up = levelorder('parents', node_out, {}) - local seen = {} - for _, node in ipairs(up) do - seen[node] = bor(seen[node] or 0, 1) - end - for _, node in ipairs(down) do - seen[node] = bor(seen[node] or 0, 2) - if seen[node] == 3 then - insert(nodes, node) - end - end - return nodes -end - -local function traverse_all(nodes_in, nodes_out, nodes) - local all_in = {children={}} - local all_out = {parents={}} - for _, node in ipairs(nodes_in) do insert(all_in.children, node) end - for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end - return traverse(all_in, all_out, nodes or {}) -end - function Model:init(nodes_in, nodes_out) assert(#nodes_in > 0, #nodes_in) assert(#nodes_out > 0, #nodes_out) @@ -338,7 +313,7 @@ function Model:reset() self.n_param = 0 for _, node in ipairs(self.nodes) do node:init_weights() - self.n_param = self.n_param + node:getsize() + self.n_param = self.n_param + node:get_size() end end @@ -359,10 +334,75 @@ function Model:forward(X) return outputs end +function Model:collect() + -- return a flat array of all the weights in the graph. + -- if Lua had slices, we wouldn't need this. future library idea? + assert(self.n_param >= 0, self.n_param) + local W = zeros(self.n_param) + local i = 0 + for _, node in ipairs(self.nodes) do + for _, w in ipairs(node.weights) do + for j, v in ipairs(w) do + W[i+j] = v + end + i = i + #w + end + end + return W +end + +function Model:distribute(W) + -- inverse operation of collect(). + local i = 0 + for _, node in ipairs(self.nodes) do + for _, w in ipairs(node.weights) do + for j, v in ipairs(w) do + w[j] = W[i+j] + end + i = i + #w + end + end +end + +function Model:save(fn) + local fn = fn or 'network.txt' + local f = open(fn, 'w') + if f == nil then error("Failed to save network to file "..fn) end + local W = self:collect() + for i, v in ipairs(W) do + f:write(v) + f:write('\n') + end + f:close() +end + +function Model:load(fn) + local fn = fn or 'network.txt' + local f = open(fn, 'r') + if f == nil then + error("Failed to load network from file "..fn) + end + + local W = zeros(self.n_param) + local i = 0 + for line in f:lines() do + i = i + 1 + local n = tonumber(line) + if n == nil then + error("Failed reading line "..tostring(i).." of file "..fn) + end + W[i] = n + end + f:close() + + self:distribute(W) +end + return { uniform = uniform, normal = normal, + copy = copy, zeros = zeros, init_zeros = init_zeros, init_he_uniform = init_he_uniform,