cleanup
This commit is contained in:
parent
d384635000
commit
01d7e5e230
1 changed files with 15 additions and 15 deletions
30
main.lua
30
main.lua
|
@ -35,7 +35,6 @@ local det_epsilon = true -- take random actions with probability eps.
|
|||
local eps_start = 1.0 * frameskip / 64
|
||||
local eps_stop = 0.1 * eps_start
|
||||
local eps_frames = 2000000
|
||||
local learn_start_select = false
|
||||
--
|
||||
local epoch_trials = 40
|
||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||
|
@ -224,7 +223,7 @@ local function rchoice2(t)
|
|||
return t[1] > random()
|
||||
end
|
||||
|
||||
local function rbool(t)
|
||||
local function rbool()
|
||||
return 0.5 >= random()
|
||||
end
|
||||
|
||||
|
@ -273,7 +272,7 @@ local nn = require("nn")
|
|||
|
||||
local network
|
||||
local nn_x, nn_tx, nn_ty, nn_y, nn_z
|
||||
local function make_network(input_size, buttons)
|
||||
local function make_network(input_size)
|
||||
nn_x = nn.Input({input_size})
|
||||
nn_tx = nn.Input({tile_count})
|
||||
nn_ty = nn_tx:feed(nn.Embed(256, 2))
|
||||
|
@ -288,6 +287,11 @@ local function make_network(input_size, buttons)
|
|||
nn_y = nn_y:feed(nn.Relu())
|
||||
nn_y = nn_y:feed(nn.Dense(48))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
else
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
end
|
||||
|
||||
nn_z = nn_y
|
||||
|
@ -471,7 +475,6 @@ local function handle_enemies()
|
|||
end
|
||||
|
||||
local function handle_fireballs()
|
||||
-- fireballs
|
||||
for i = 0, 1 do
|
||||
local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC)
|
||||
x, y = x + 4, y + 4
|
||||
|
@ -681,17 +684,14 @@ local function do_reset()
|
|||
prepare_epoch()
|
||||
end
|
||||
|
||||
-- bit of a hack:
|
||||
if get_state() == 'loading' then advance() end
|
||||
if get_state() == 'loading' then advance() end -- kind of a hack.
|
||||
reward = 0
|
||||
powerup_old = R(0x754)
|
||||
status_old = R(0x756)
|
||||
coins_old = R(0x7ED) * 10 + R(0x7EE)
|
||||
score_old = get_score()
|
||||
|
||||
-- set lives to 0. you only got one shot!
|
||||
-- unless you get a 1-up, in which case, please continue!
|
||||
W(0x75A, 1)
|
||||
W(0x75A, 1) -- set number of lives. (mario gets n+1 chances)
|
||||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||
max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||
|
@ -699,7 +699,6 @@ local function do_reset()
|
|||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
--print("end of trial reward:", reward)
|
||||
else
|
||||
savestate.save(startsave)
|
||||
end
|
||||
|
@ -716,7 +715,7 @@ local function do_reset()
|
|||
end
|
||||
|
||||
local function init()
|
||||
network = make_network(input_size, learn_start_select and 8 or 6)
|
||||
network = make_network(input_size)
|
||||
network:reset()
|
||||
network:print()
|
||||
print("parameters:", network.n_param)
|
||||
|
@ -754,7 +753,9 @@ local function doit(dummy)
|
|||
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
|
||||
|
||||
screen_scroll_delta = screen_scroll_delta + R(0x775)
|
||||
|
||||
if dummy == true then
|
||||
-- don't invoke AI this frame. (keep holding the old inputs)
|
||||
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
||||
return
|
||||
end
|
||||
|
@ -870,9 +871,7 @@ while true do
|
|||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
while bad_states[get_state()] do
|
||||
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
|
||||
-- mash the start button until we have control.
|
||||
-- TODO: learn this too.
|
||||
local jp_mash = {
|
||||
up = false,
|
||||
down = false,
|
||||
|
@ -890,8 +889,7 @@ while true do
|
|||
advance()
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
-- bit of a hack:
|
||||
while get_state() == "loading" do advance() end
|
||||
while get_state() == "loading" do advance() end -- kind of a hack.
|
||||
state_old = get_state()
|
||||
end
|
||||
|
||||
|
@ -911,6 +909,8 @@ while true do
|
|||
W(0x75A, 1)
|
||||
end
|
||||
|
||||
-- FIXME: if the game lags then we might miss our frame to change inputs!
|
||||
-- don't rely on emu.framecount.
|
||||
local doot = jp == nil or emu.framecount() % frameskip == 0
|
||||
doit(not doot)
|
||||
|
||||
|
|
Loading…
Reference in a new issue