diff --git a/main.lua b/main.lua index adf98f2..a823a52 100644 --- a/main.lua +++ b/main.lua @@ -22,6 +22,7 @@ end -- localize some stuff. +local print = print local ipairs = ipairs local pairs = pairs local select = select @@ -387,6 +388,35 @@ local function handle_tiles() end end +local function doreset() + print("resetting in state:", getstate()) + + if once then + savestate.load(startsave) + print("end of trial reward:", reward) + print() + else + savestate.save(startsave) + end + once = true + + -- bit of a hack: + if getstate() == 'loading' then advance() end + reward = 0 + powerup_old = R(0x754) + status_old = R(0x756) + coins_old = R(0x7ED) * 10 + R(0x7EE) + + -- set lives to 0. you only got one shot! + -- unless you get a 1-up, in which case, please continue! + W(0x75A, 0) + + emu.frameadvance() -- prevents emulator from quirking up. + + reset = false + if network ~= nil then network:reset() end -- FIXME: hack +end + while true do while fuckstates[getstate()] do --gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F') @@ -415,34 +445,7 @@ while true do state_old = getstate() end - if reset then - print("resetting in state:", getstate()) - - if once then - savestate.load(startsave) - print("end of trial reward:", reward) - print() - else - savestate.save(startsave) - end - once = true - - -- bit of a hack: - if getstate() == 'loading' then advance() end - reward = 0 - powerup_old = R(0x754) - status_old = R(0x756) - coins_old = R(0x7ED) * 10 + R(0x7EE) - - -- set lives to 0. you only got one shot! - -- unless you get a 1-up, in which case, please continue! - W(0x75A, 0) - - emu.frameadvance() -- prevents emulator from quirking up. - - reset = false - if network ~= nil then network:reset() end -- FIXME: hack - end + if reset then doreset() end if not enable_network then -- infinite time cheat. super handy for testing. diff --git a/nn.lua b/nn.lua index 26c5a49..7eced6a 100644 --- a/nn.lua +++ b/nn.lua @@ -1,3 +1,4 @@ +local print = print local tostring = tostring local ipairs = ipairs local pairs = pairs @@ -23,8 +24,8 @@ local function contains(t, a) return false end -local function normal() - return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform()) / 2 +local function normal() -- box muller + return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform()) end local function zeros(n, out)