This commit is contained in:
Connor Olding 2017-09-09 19:37:01 +00:00
parent d384635000
commit 01d7e5e230

View file

@ -35,7 +35,6 @@ local det_epsilon = true -- take random actions with probability eps.
-- NOTE(review): presumably the initial value of eps; it scales linearly with
-- frameskip and equals 1.0 when frameskip is 64.
local eps_start = 1.0 * frameskip / 64
-- NOTE(review): presumably the final value of eps; always one tenth of eps_start.
local eps_stop = 0.1 * eps_start
-- NOTE(review): looks like the number of frames over which eps goes from
-- eps_start to eps_stop; the schedule itself is not visible here -- confirm.
local eps_frames = 2000000
-- when true, init() passes 8 instead of 6 as the second argument of make_network.
local learn_start_select = false
--
-- NOTE(review): presumably the number of trials per epoch -- confirm.
-- also appears in the max_time formula in do_reset:
-- a larger value makes max_time grow more slowly with epoch_i.
local epoch_trials = 40
local negate_trials = true -- try pairs of normal and negated noise directions.
@ -224,7 +223,7 @@ local function rchoice2(t)
return t[1] > random()
end
-- returns a random boolean: true whenever random() yields a value <= 0.5.
-- NOTE(review): a fair coin flip only if random() is uniform on [0, 1),
-- e.g. a local alias of math.random -- confirm where `random` is defined.
-- takes no parameters; any argument a caller passes is ignored.
local function rbool()
    return 0.5 >= random()
end
@ -273,7 +272,7 @@ local nn = require("nn")
local network
local nn_x, nn_tx, nn_ty, nn_y, nn_z
local function make_network(input_size, buttons)
local function make_network(input_size)
nn_x = nn.Input({input_size})
nn_tx = nn.Input({tile_count})
nn_ty = nn_tx:feed(nn.Embed(256, 2))
@ -288,6 +287,11 @@ local function make_network(input_size, buttons)
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(48))
nn_y = nn_y:feed(nn.Relu())
else
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
end
nn_z = nn_y
@ -471,7 +475,6 @@ local function handle_enemies()
end
local function handle_fireballs()
-- fireballs
for i = 0, 1 do
local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC)
x, y = x + 4, y + 4
@ -681,17 +684,14 @@ local function do_reset()
prepare_epoch()
end
-- bit of a hack:
if get_state() == 'loading' then advance() end
if get_state() == 'loading' then advance() end -- kind of a hack.
reward = 0
powerup_old = R(0x754)
status_old = R(0x756)
coins_old = R(0x7ED) * 10 + R(0x7EE)
score_old = get_score()
-- set lives to 0. you only got one shot!
-- unless you get a 1-up, in which case, please continue!
W(0x75A, 1)
W(0x75A, 1) -- set number of lives. (mario gets n+1 chances)
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
@ -699,7 +699,6 @@ local function do_reset()
if once then
savestate.load(startsave)
--print("end of trial reward:", reward)
else
savestate.save(startsave)
end
@ -716,7 +715,7 @@ local function do_reset()
end
local function init()
network = make_network(input_size, learn_start_select and 8 or 6)
network = make_network(input_size)
network:reset()
network:print()
print("parameters:", network.n_param)
@ -754,7 +753,9 @@ local function doit(dummy)
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
screen_scroll_delta = screen_scroll_delta + R(0x775)
if dummy == true then
-- don't invoke AI this frame. (keep holding the old inputs)
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
return
end
@ -870,9 +871,7 @@ while true do
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
while bad_states[get_state()] do
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
-- mash the start button until we have control.
-- TODO: learn this too.
local jp_mash = {
up = false,
down = false,
@ -890,8 +889,7 @@ while true do
advance()
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
-- bit of a hack:
while get_state() == "loading" do advance() end
while get_state() == "loading" do advance() end -- kind of a hack.
state_old = get_state()
end
@ -911,6 +909,8 @@ while true do
W(0x75A, 1)
end
-- FIXME: if the game lags then we might miss our frame to change inputs!
-- don't rely on emu.framecount.
local doot = jp == nil or emu.framecount() % frameskip == 0
doit(not doot)