add frameskip
This commit is contained in:
parent
3d64df0574
commit
3b4e195ae6
2 changed files with 71 additions and 68 deletions
135
main.lua
135
main.lua
|
@ -30,12 +30,13 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
|||
|
||||
local playable_mode = false
|
||||
--
|
||||
local frameskip = 4
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||
local det_epsilon = true -- take random actions with probability eps.
|
||||
-- using parameters from DQN... sorta.
|
||||
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref.
|
||||
local eps_stop = 0.1 * 1/60 -- "
|
||||
-- using parameters from DQN
|
||||
local eps_start = 1.0 * frameskip / 64
|
||||
local eps_stop = 0.1 * eps_start
|
||||
local eps_frames = 1000000
|
||||
local learn_start_select = false
|
||||
--
|
||||
|
@ -96,6 +97,9 @@ local sprite_input = {}
|
|||
local tile_input = {}
|
||||
local extra_input = {}
|
||||
|
||||
local jp
|
||||
|
||||
local screen_scroll_delta
|
||||
local reward
|
||||
local all_rewards = {}
|
||||
|
||||
|
@ -394,15 +398,6 @@ local function advance()
|
|||
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
|
||||
end
|
||||
|
||||
while false do
|
||||
local state = get_state()
|
||||
if state ~= state_old then
|
||||
print(emu.framecount(), state)
|
||||
state_old = state
|
||||
end
|
||||
advance()
|
||||
end
|
||||
|
||||
local function handle_enemies()
|
||||
-- enemies, flagpole
|
||||
for i = 0, 5 do
|
||||
|
@ -672,6 +667,8 @@ local function do_reset()
|
|||
end
|
||||
once = true
|
||||
|
||||
jp = nil
|
||||
screen_scroll_delta = 0
|
||||
emu.frameadvance() -- prevents emulator from quirking up.
|
||||
|
||||
--print()
|
||||
|
@ -697,50 +694,11 @@ init()
|
|||
|
||||
local dummy_softmax_values = {0, 0}
|
||||
|
||||
while true do
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
while bad_states[get_state()] do
|
||||
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
|
||||
-- mash the start button until we have control.
|
||||
-- TODO: learn this too.
|
||||
--local jp = joypad.read(1)
|
||||
local jp = {
|
||||
up = false,
|
||||
down = false,
|
||||
left = false,
|
||||
right = false,
|
||||
A = false,
|
||||
B = false,
|
||||
select = false,
|
||||
start = emu.framecount() % 2 == 1,
|
||||
}
|
||||
joypad.write(1, jp)
|
||||
|
||||
reset = true
|
||||
|
||||
advance()
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
-- bit of a hack:
|
||||
while get_state() == "loading" do advance() end
|
||||
state_old = get_state()
|
||||
end
|
||||
|
||||
if reset then do_reset() end
|
||||
|
||||
if not enable_network then
|
||||
-- infinite time cheat. super handy for testing.
|
||||
if R(0xE) == 8 then
|
||||
set_timer(667)
|
||||
poketime = true
|
||||
elseif poketime then
|
||||
poketime = false
|
||||
set_timer(1)
|
||||
end
|
||||
|
||||
-- infinite lives.
|
||||
W(0x75A, 1)
|
||||
local function doit(dummy)
|
||||
screen_scroll_delta = screen_scroll_delta + R(0x775)
|
||||
if dummy == true then
|
||||
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
||||
return
|
||||
end
|
||||
|
||||
empty(sprite_input)
|
||||
|
@ -776,14 +734,12 @@ while true do
|
|||
local powerup_delta = powerup_old - powerup
|
||||
-- 2 is fire mario.
|
||||
local status_delta = clamp(status - status_old, -1, 1)
|
||||
local screen_scroll_delta = R(0x775)
|
||||
local flagpole_bonus = R(0xE) == 4 and 1 or 0
|
||||
local flagpole_bonus = R(0xE) == 4 and frameskip or 0
|
||||
--local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||
local score_delta = get_score() - score_old
|
||||
if score_delta < 0 then score_delta = 0 end
|
||||
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
|
||||
|
||||
-- TODO: add ingame score to reward.
|
||||
screen_scroll_delta = 0
|
||||
|
||||
if not ingame_paused then reward = reward + reward_delta end
|
||||
|
||||
|
@ -822,7 +778,7 @@ while true do
|
|||
-- that's cheating! unpause.
|
||||
force_start = ingame_paused and timer == 0
|
||||
|
||||
local X = {} -- TODO: cache.
|
||||
local X = {}
|
||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
||||
|
@ -849,7 +805,7 @@ while true do
|
|||
learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
|
||||
}
|
||||
|
||||
local jp = {
|
||||
jp = {
|
||||
up = choose(softmaxed[1]),
|
||||
down = choose(softmaxed[2]),
|
||||
left = choose(softmaxed[3]),
|
||||
|
@ -880,8 +836,6 @@ while true do
|
|||
select = false,
|
||||
}
|
||||
end
|
||||
|
||||
joypad.write(1, jp)
|
||||
end
|
||||
|
||||
coins_old = coins
|
||||
|
@ -890,5 +844,58 @@ while true do
|
|||
force_start_old = force_start
|
||||
state_old = get_state()
|
||||
score_old = get_score()
|
||||
end
|
||||
|
||||
while true do
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
while bad_states[get_state()] do
|
||||
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
|
||||
-- mash the start button until we have control.
|
||||
-- TODO: learn this too.
|
||||
local jp_mash = {
|
||||
up = false,
|
||||
down = false,
|
||||
left = false,
|
||||
right = false,
|
||||
A = false,
|
||||
B = false,
|
||||
select = false,
|
||||
start = emu.framecount() % 2 == 1,
|
||||
}
|
||||
joypad.write(1, jp_mash)
|
||||
|
||||
reset = true
|
||||
|
||||
advance()
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
-- bit of a hack:
|
||||
while get_state() == "loading" do advance() end
|
||||
state_old = get_state()
|
||||
end
|
||||
|
||||
if reset then do_reset() end
|
||||
|
||||
if not enable_network then
|
||||
-- infinite time cheat. super handy for testing.
|
||||
if R(0xE) == 8 then
|
||||
set_timer(667)
|
||||
poketime = true
|
||||
elseif poketime then
|
||||
poketime = false
|
||||
set_timer(1)
|
||||
end
|
||||
|
||||
-- infinite lives.
|
||||
W(0x75A, 1)
|
||||
end
|
||||
|
||||
local doot = jp == nil or emu.framecount() % frameskip == 0
|
||||
doit(not doot)
|
||||
|
||||
-- jp might still be nil if we're not ingame or we're not playing.
|
||||
if jp ~= nil then joypad.write(1, jp) end
|
||||
|
||||
advance()
|
||||
end
|
||||
|
|
4
nn.lua
4
nn.lua
|
@ -422,7 +422,6 @@ function Relu:init()
|
|||
end
|
||||
|
||||
function Relu:reset_cache(bs)
|
||||
print("clearing cache:", self.name, bs)
|
||||
self.bs = bs
|
||||
|
||||
self.cache = cache(bs, self.shape_out)
|
||||
|
@ -455,7 +454,6 @@ function Gelu:init()
|
|||
end
|
||||
|
||||
function Gelu:reset_cache(bs)
|
||||
print("clearing cache:", self.name, bs)
|
||||
self.bs = bs
|
||||
|
||||
self.cache = cache(bs, self.shape_out)
|
||||
|
@ -513,7 +511,6 @@ function Dense:make_shape(parent)
|
|||
end
|
||||
|
||||
function Dense:reset_cache(bs)
|
||||
print("clearing cache:", self.name, bs)
|
||||
self.bs = bs
|
||||
|
||||
self.cache = cache(bs, self.shape_out)
|
||||
|
@ -567,7 +564,6 @@ function Softmax:init()
|
|||
end
|
||||
|
||||
function Softmax:reset_cache(bs)
|
||||
print("clearing cache:", self.name, bs)
|
||||
self.bs = bs
|
||||
|
||||
self.cache = cache(bs, self.shape_out)
|
||||
|
|
Loading…
Reference in a new issue