diff --git a/main.lua b/main.lua index a9c8d25..c6aa93e 100644 --- a/main.lua +++ b/main.lua @@ -35,7 +35,6 @@ local det_epsilon = true -- take random actions with probability eps. local eps_start = 1.0 * frameskip / 64 local eps_stop = 0.1 * eps_start local eps_frames = 2000000 -local learn_start_select = false -- local epoch_trials = 40 local negate_trials = true -- try pairs of normal and negated noise directions. @@ -224,7 +223,7 @@ local function rchoice2(t) return t[1] > random() end -local function rbool(t) +local function rbool() return 0.5 >= random() end @@ -273,7 +272,7 @@ local nn = require("nn") local network local nn_x, nn_tx, nn_ty, nn_y, nn_z -local function make_network(input_size, buttons) +local function make_network(input_size) nn_x = nn.Input({input_size}) nn_tx = nn.Input({tile_count}) nn_ty = nn_tx:feed(nn.Embed(256, 2)) @@ -288,6 +287,11 @@ local function make_network(input_size, buttons) nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Dense(48)) nn_y = nn_y:feed(nn.Relu()) + else + nn_y = nn_y:feed(nn.Dense(128)) + nn_y = nn_y:feed(nn.Gelu()) + nn_y = nn_y:feed(nn.Dense(128)) + nn_y = nn_y:feed(nn.Gelu()) end nn_z = nn_y @@ -471,7 +475,6 @@ local function handle_enemies() end local function handle_fireballs() - -- fireballs for i = 0, 1 do local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC) x, y = x + 4, y + 4 @@ -681,17 +684,14 @@ local function do_reset() prepare_epoch() end - -- bit of a hack: - if get_state() == 'loading' then advance() end + if get_state() == 'loading' then advance() end -- kind of a hack. reward = 0 powerup_old = R(0x754) status_old = R(0x756) coins_old = R(0x7ED) * 10 + R(0x7EE) score_old = get_score() - -- set lives to 0. you only got one shot! - -- unless you get a 1-up, in which case, please continue! - W(0x75A, 1) + W(0x75A, 1) -- set number of lives. (mario gets n+1 chances) --max_time = min(log(epoch_i) * 10 + 100, cap_time) max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time) @@ -699,7 +699,6 @@ local function do_reset() if once then savestate.load(startsave) - --print("end of trial reward:", reward) else savestate.save(startsave) end @@ -716,7 +715,7 @@ local function do_reset() end local function init() - network = make_network(input_size, learn_start_select and 8 or 6) + network = make_network(input_size) network:reset() network:print() print("parameters:", network.n_param) @@ -754,7 +753,9 @@ local function doit(dummy) gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F') screen_scroll_delta = screen_scroll_delta + R(0x775) + if dummy == true then + -- don't invoke AI this frame. (keep holding the old inputs) gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') return end @@ -870,9 +871,7 @@ while true do gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') while bad_states[get_state()] do - --gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F') -- mash the start button until we have control. - -- TODO: learn this too. local jp_mash = { up = false, down = false, @@ -890,8 +889,7 @@ while true do advance() gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') - -- bit of a hack: - while get_state() == "loading" do advance() end + while get_state() == "loading" do advance() end -- kind of a hack. state_old = get_state() end @@ -911,6 +909,8 @@ while true do W(0x75A, 1) end + -- FIXME: if the game lags then we might miss our frame to change inputs! + -- don't rely on emu.framecount. local doot = jp == nil or emu.framecount() % frameskip == 0 doit(not doot)