From 3b4e195ae666e253a18e4ce9a08c206e14419d99 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 7 Sep 2017 21:20:53 +0000 Subject: [PATCH] add frameskip --- main.lua | 135 +++++++++++++++++++++++++++++-------------------------- nn.lua | 4 -- 2 files changed, 71 insertions(+), 68 deletions(-) diff --git a/main.lua b/main.lua index 14acbbd..36ceb39 100644 --- a/main.lua +++ b/main.lua @@ -30,12 +30,13 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end local playable_mode = false -- +local frameskip = 4 -- true greedy epsilon has both deterministic and det_epsilon set. local deterministic = true -- use argmax on outputs instead of random sampling. local det_epsilon = true -- take random actions with probability eps. --- using parameters from DQN... sorta. -local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref. -local eps_stop = 0.1 * 1/60 -- " +-- using parameters from DQN +local eps_start = 1.0 * frameskip / 64 +local eps_stop = 0.1 * eps_start local eps_frames = 1000000 local learn_start_select = false -- @@ -96,6 +97,9 @@ local sprite_input = {} local tile_input = {} local extra_input = {} +local jp + +local screen_scroll_delta local reward local all_rewards = {} @@ -394,15 +398,6 @@ local function advance() while R(0x774) > 0 do emu.frameadvance() end -- also lag frames. end -while false do - local state = get_state() - if state ~= state_old then - print(emu.framecount(), state) - state_old = state - end - advance() -end - local function handle_enemies() -- enemies, flagpole for i = 0, 5 do @@ -672,6 +667,8 @@ local function do_reset() end once = true + jp = nil + screen_scroll_delta = 0 emu.frameadvance() -- prevents emulator from quirking up. --print() @@ -697,50 +694,11 @@ init() local dummy_softmax_values = {0, 0} -while true do - gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') - - while bad_states[get_state()] do - --gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F') - -- mash the start button until we have control. - -- TODO: learn this too. - --local jp = joypad.read(1) - local jp = { - up = false, - down = false, - left = false, - right = false, - A = false, - B = false, - select = false, - start = emu.framecount() % 2 == 1, - } - joypad.write(1, jp) - - reset = true - - advance() - gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') - - -- bit of a hack: - while get_state() == "loading" do advance() end - state_old = get_state() - end - - if reset then do_reset() end - - if not enable_network then - -- infinite time cheat. super handy for testing. - if R(0xE) == 8 then - set_timer(667) - poketime = true - elseif poketime then - poketime = false - set_timer(1) - end - - -- infinite lives. - W(0x75A, 1) +local function doit(dummy) + screen_scroll_delta = screen_scroll_delta + R(0x775) + if dummy == true then + gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') + return end empty(sprite_input) @@ -776,14 +734,12 @@ while true do local powerup_delta = powerup_old - powerup -- 2 is fire mario. local status_delta = clamp(status - status_old, -1, 1) - local screen_scroll_delta = R(0x775) - local flagpole_bonus = R(0xE) == 4 and 1 or 0 + local flagpole_bonus = R(0xE) == 4 and frameskip or 0 --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus local score_delta = get_score() - score_old if score_delta < 0 then score_delta = 0 end local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus - - -- TODO: add ingame score to reward. + screen_scroll_delta = 0 if not ingame_paused then reward = reward + reward_delta end @@ -822,7 +778,7 @@ while true do -- that's cheating! unpause. force_start = ingame_paused and timer == 0 - local X = {} -- TODO: cache. + local X = {} for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(tile_input) do insert(X, v / 256) end for i, v in ipairs(extra_input) do insert(X, v / 256) end @@ -849,7 +805,7 @@ while true do learn_start_select and outputs[nn_z[8]] or dummy_softmax_values, } - local jp = { + jp = { up = choose(softmaxed[1]), down = choose(softmaxed[2]), left = choose(softmaxed[3]), @@ -880,8 +836,6 @@ while true do select = false, } end - - joypad.write(1, jp) end coins_old = coins @@ -890,5 +844,58 @@ while true do force_start_old = force_start state_old = get_state() score_old = get_score() +end + +while true do + gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') + + while bad_states[get_state()] do + --gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F') + -- mash the start button until we have control. + -- TODO: learn this too. + local jp_mash = { + up = false, + down = false, + left = false, + right = false, + A = false, + B = false, + select = false, + start = emu.framecount() % 2 == 1, + } + joypad.write(1, jp_mash) + + reset = true + + advance() + gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') + + -- bit of a hack: + while get_state() == "loading" do advance() end + state_old = get_state() + end + + if reset then do_reset() end + + if not enable_network then + -- infinite time cheat. super handy for testing. + if R(0xE) == 8 then + set_timer(667) + poketime = true + elseif poketime then + poketime = false + set_timer(1) + end + + -- infinite lives. + W(0x75A, 1) + end + + local doot = jp == nil or emu.framecount() % frameskip == 0 + doit(not doot) + + -- jp might still be nil if we're not ingame or we're not playing. + if jp ~= nil then joypad.write(1, jp) end + advance() end diff --git a/nn.lua b/nn.lua index 2668592..ca75d20 100644 --- a/nn.lua +++ b/nn.lua @@ -422,7 +422,6 @@ function Relu:init() end function Relu:reset_cache(bs) - print("clearing cache:", self.name, bs) self.bs = bs self.cache = cache(bs, self.shape_out) @@ -455,7 +454,6 @@ function Gelu:init() end function Gelu:reset_cache(bs) - print("clearing cache:", self.name, bs) self.bs = bs self.cache = cache(bs, self.shape_out) @@ -513,7 +511,6 @@ function Dense:make_shape(parent) end function Dense:reset_cache(bs) - print("clearing cache:", self.name, bs) self.bs = bs self.cache = cache(bs, self.shape_out) @@ -567,7 +564,6 @@ function Softmax:init() end function Softmax:reset_cache(bs) - print("clearing cache:", self.name, bs) self.bs = bs self.cache = cache(bs, self.shape_out)