add frameskip

This commit is contained in:
Connor Olding 2017-09-07 21:20:53 +00:00
parent 3d64df0574
commit 3b4e195ae6
2 changed files with 71 additions and 68 deletions

135
main.lua
View file

@ -30,12 +30,13 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
local playable_mode = false local playable_mode = false
-- --
local frameskip = 4
-- true greedy epsilon has both deterministic and det_epsilon set. -- true greedy epsilon has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling. local deterministic = true -- use argmax on outputs instead of random sampling.
local det_epsilon = true -- take random actions with probability eps. local det_epsilon = true -- take random actions with probability eps.
-- using parameters from DQN... sorta. -- using parameters from DQN
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref. local eps_start = 1.0 * frameskip / 64
local eps_stop = 0.1 * 1/60 -- " local eps_stop = 0.1 * eps_start
local eps_frames = 1000000 local eps_frames = 1000000
local learn_start_select = false local learn_start_select = false
-- --
@ -96,6 +97,9 @@ local sprite_input = {}
local tile_input = {} local tile_input = {}
local extra_input = {} local extra_input = {}
local jp
local screen_scroll_delta
local reward local reward
local all_rewards = {} local all_rewards = {}
@ -394,15 +398,6 @@ local function advance()
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames. while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
end end
while false do
local state = get_state()
if state ~= state_old then
print(emu.framecount(), state)
state_old = state
end
advance()
end
local function handle_enemies() local function handle_enemies()
-- enemies, flagpole -- enemies, flagpole
for i = 0, 5 do for i = 0, 5 do
@ -672,6 +667,8 @@ local function do_reset()
end end
once = true once = true
jp = nil
screen_scroll_delta = 0
emu.frameadvance() -- prevents emulator from quirking up. emu.frameadvance() -- prevents emulator from quirking up.
--print() --print()
@ -697,50 +694,11 @@ init()
local dummy_softmax_values = {0, 0} local dummy_softmax_values = {0, 0}
while true do local function doit(dummy)
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') screen_scroll_delta = screen_scroll_delta + R(0x775)
if dummy == true then
while bad_states[get_state()] do gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F') return
-- mash the start button until we have control.
-- TODO: learn this too.
--local jp = joypad.read(1)
local jp = {
up = false,
down = false,
left = false,
right = false,
A = false,
B = false,
select = false,
start = emu.framecount() % 2 == 1,
}
joypad.write(1, jp)
reset = true
advance()
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
-- bit of a hack:
while get_state() == "loading" do advance() end
state_old = get_state()
end
if reset then do_reset() end
if not enable_network then
-- infinite time cheat. super handy for testing.
if R(0xE) == 8 then
set_timer(667)
poketime = true
elseif poketime then
poketime = false
set_timer(1)
end
-- infinite lives.
W(0x75A, 1)
end end
empty(sprite_input) empty(sprite_input)
@ -776,14 +734,12 @@ while true do
local powerup_delta = powerup_old - powerup local powerup_delta = powerup_old - powerup
-- 2 is fire mario. -- 2 is fire mario.
local status_delta = clamp(status - status_old, -1, 1) local status_delta = clamp(status - status_old, -1, 1)
local screen_scroll_delta = R(0x775) local flagpole_bonus = R(0xE) == 4 and frameskip or 0
local flagpole_bonus = R(0xE) == 4 and 1 or 0
--local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
local score_delta = get_score() - score_old local score_delta = get_score() - score_old
if score_delta < 0 then score_delta = 0 end if score_delta < 0 then score_delta = 0 end
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
screen_scroll_delta = 0
-- TODO: add ingame score to reward.
if not ingame_paused then reward = reward + reward_delta end if not ingame_paused then reward = reward + reward_delta end
@ -822,7 +778,7 @@ while true do
-- that's cheating! unpause. -- that's cheating! unpause.
force_start = ingame_paused and timer == 0 force_start = ingame_paused and timer == 0
local X = {} -- TODO: cache. local X = {}
for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(sprite_input) do insert(X, v / 256) end
for i, v in ipairs(tile_input) do insert(X, v / 256) end for i, v in ipairs(tile_input) do insert(X, v / 256) end
for i, v in ipairs(extra_input) do insert(X, v / 256) end for i, v in ipairs(extra_input) do insert(X, v / 256) end
@ -849,7 +805,7 @@ while true do
learn_start_select and outputs[nn_z[8]] or dummy_softmax_values, learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
} }
local jp = { jp = {
up = choose(softmaxed[1]), up = choose(softmaxed[1]),
down = choose(softmaxed[2]), down = choose(softmaxed[2]),
left = choose(softmaxed[3]), left = choose(softmaxed[3]),
@ -880,8 +836,6 @@ while true do
select = false, select = false,
} }
end end
joypad.write(1, jp)
end end
coins_old = coins coins_old = coins
@ -890,5 +844,58 @@ while true do
force_start_old = force_start force_start_old = force_start
state_old = get_state() state_old = get_state()
score_old = get_score() score_old = get_score()
end
while true do
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
while bad_states[get_state()] do
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
-- mash the start button until we have control.
-- TODO: learn this too.
local jp_mash = {
up = false,
down = false,
left = false,
right = false,
A = false,
B = false,
select = false,
start = emu.framecount() % 2 == 1,
}
joypad.write(1, jp_mash)
reset = true
advance()
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
-- bit of a hack:
while get_state() == "loading" do advance() end
state_old = get_state()
end
if reset then do_reset() end
if not enable_network then
-- infinite time cheat. super handy for testing.
if R(0xE) == 8 then
set_timer(667)
poketime = true
elseif poketime then
poketime = false
set_timer(1)
end
-- infinite lives.
W(0x75A, 1)
end
local doot = jp == nil or emu.framecount() % frameskip == 0
doit(not doot)
-- jp might still be nil if we're not ingame or we're not playing.
if jp ~= nil then joypad.write(1, jp) end
advance() advance()
end end

4
nn.lua
View file

@ -422,7 +422,6 @@ function Relu:init()
end end
function Relu:reset_cache(bs) function Relu:reset_cache(bs)
print("clearing cache:", self.name, bs)
self.bs = bs self.bs = bs
self.cache = cache(bs, self.shape_out) self.cache = cache(bs, self.shape_out)
@ -455,7 +454,6 @@ function Gelu:init()
end end
function Gelu:reset_cache(bs) function Gelu:reset_cache(bs)
print("clearing cache:", self.name, bs)
self.bs = bs self.bs = bs
self.cache = cache(bs, self.shape_out) self.cache = cache(bs, self.shape_out)
@ -513,7 +511,6 @@ function Dense:make_shape(parent)
end end
function Dense:reset_cache(bs) function Dense:reset_cache(bs)
print("clearing cache:", self.name, bs)
self.bs = bs self.bs = bs
self.cache = cache(bs, self.shape_out) self.cache = cache(bs, self.shape_out)
@ -567,7 +564,6 @@ function Softmax:init()
end end
function Softmax:reset_cache(bs) function Softmax:reset_cache(bs)
print("clearing cache:", self.name, bs)
self.bs = bs self.bs = bs
self.cache = cache(bs, self.shape_out) self.cache = cache(bs, self.shape_out)