diff --git a/config.lua b/config.lua new file mode 100644 index 0000000..a619aa8 --- /dev/null +++ b/config.lua @@ -0,0 +1,48 @@ +local function approx_cossim(dim) + return math.pow(1.521 * dim - 0.521, -0.5026) +end + +local cfg = { + defer_prints = true, + + playable_mode = false, + start_big = false, --true + starting_lives = 0, --1 + -- + init_zeros = true, -- instead of he_normal noise or whatever. + frameskip = 4, + -- true greedy epsilon has both deterministic and det_epsilon set. + deterministic = true, -- use argmax on outputs instead of random sampling. + det_epsilon = false, -- take random actions with probability eps. + -- + epoch_trials = 50, + epoch_top_trials = 25, -- new with ARS. + unperturbed_trial = true, -- do a trial without any noise. + negate_trials = true, -- try pairs of normal and negated noise directions. + time_inputs = true, -- binary inputs of global frame count + -- ^ note that this now doubles the effective trials. + deviation = 0.05, --0.075 --0.1 + --learning_rate = 0.01 / approx_cossim(7051) + learning_rate = 1.0, + --learning_rate = 0.0032 / approx_cossim(66573) + --learning_rate = 0.0056 / approx_cossim(66573) + weight_decay = 0.00032, --0.001 --0.0023 + -- + cap_time = 200, --400 + timer_loser = 1/2, + decrement_reward = false, -- bad idea, encourages mario to kill himself +} + +cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials) + +cfg.eps_start = 1.0 * cfg.frameskip / 64 +cfg.eps_stop = 0.1 * cfg.eps_start +cfg.eps_frames = 1000000 +cfg.enable_overlay = cfg.playable_mode +cfg.enable_network = not cfg.playable_mode + +return setmetatable(cfg, { + __index = function(t, n) + error("cannot use undeclared config '" .. tostring(n) .. "'", 2) + end +}) diff --git a/gameconfig.lua b/gameconfig.lua new file mode 100644 index 0000000..3ec69ac --- /dev/null +++ b/gameconfig.lua @@ -0,0 +1,107 @@ +local gcfg = { + input_size = 60 + 4, -- TODO: let the script figure this out for us. + tile_count = 17 * 13, + + ok_routines = { + [0x4] = true, -- sliding down flagpole + [0x5] = true, -- end of level auto-walk + [0x7] = true, -- start of level auto-walk + [0x8] = true, -- normal (in control) + [0x9] = true, -- acquiring mushroom + [0xA] = true, -- losing big mario + [0xB] = true, -- uhh + [0xC] = true, -- acquiring fireflower + }, + + bad_states = { + power = true, + waiting_demo = true, + playing_demo = true, + unknown = true, + lose = true, + }, + + jp_lut = { + { -- none + up = false, down = false, left = false, right = false, + select = false, start = false, B = false, A = false, + }, { -- A + up = false, down = false, left = false, right = false, + select = false, start = false, B = false, A = true, + }, { -- L + up = false, down = false, left = true, right = false, + select = false, start = false, B = false, A = false, + }, { -- R + up = false, down = false, left = false, right = true, + select = false, start = false, B = false, A = false, + }, { -- L + B + up = false, down = false, left = true, right = false, + select = false, start = false, B = true, A = false, + }, { -- R + B + up = false, down = false, left = false, right = true, + select = false, start = false, B = true, A = false, + }, { -- L + A + up = false, down = false, left = true, right = false, + select = false, start = false, B = false, A = true, + }, { -- R + A + up = false, down = false, left = false, right = true, + select = false, start = false, B = false, A = true, + }, { -- L + A + B + up = false, down = false, left = true, right = false, + select = false, start = false, B = true, A = true, + }, { -- R + A + B + up = false, down = false, left = false, right = true, + select = false, start = false, B = true, A = true, + }, { -- D + up = false, down = true, left = false, right = false, + select = false, start = false, B = false, A = false, + }, { -- D + A + up = false, down = true, left = false, right = false, + select = false, start = false, B = false, A = true, + }, { -- U + up = true, down = false, left = false, right = false, + select = false, start = false, B = false, A = false, + }, + }, + + rotation_offsets = { -- FIXME: not all of these are pixel-perfect. + 0, -40, -- 0x00 + 6, -38, + 15, -37, + 22, -32, + 28, -28, + 32, -22, + 37, -14, + 39, -6, + 40, 0, -- 0x08 + 38, 7, + 37, 15, + 33, 23, + 27, 29, + 22, 33, + 14, 37, + 6, 39, + 0, 41, -- 0x10 + -7, 40, + -16, 38, + -22, 34, + -28, 28, + -34, 23, + -38, 16, + -40, 8, + -40, -0, -- 0x18 + -40, -6, + -38, -14, + -34, -22, + -28, -28, + -22, -32, + -16, -36, + -8, -38, + }, +} + +return setmetatable(gcfg, { + __index = function(t, n) + error("cannot use undeclared gameconfig '" .. tostring(n) .. "'", 2) + end +}) diff --git a/main.lua b/main.lua index 1ace9bf..65f89bb 100644 --- a/main.lua +++ b/main.lua @@ -26,108 +26,8 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end --randomseed(11) -local defer_prints = true - -local playable_mode = false -local start_big = false --true -local starting_lives = 0 --1 --- -local init_zeros = true -- instead of he_normal noise or whatever. -local frameskip = 4 --- true greedy epsilon has both deterministic and det_epsilon set. -local deterministic = true -- use argmax on outputs instead of random sampling. -local det_epsilon = false -- take random actions with probability eps. -local eps_start = 1.0 * frameskip / 64 -local eps_stop = 0.1 * eps_start -local eps_frames = 1000000 --- -local epoch_trials = 50 -local epoch_top_trials = 25 -- new with ARS. -local unperturbed_trial = true -- do a trial without any noise. -local negate_trials = true -- try pairs of normal and negated noise directions. -local time_inputs = true -- binary inputs of global frame count --- ^ note that this now doubles the effective trials. -local deviation = 0.05 --0.075 --0.1 -local function approx_cossim(dim) - return math.pow(1.521 * dim - 0.521, -0.5026) -end ---local learning_rate = 0.01 / approx_cossim(7051) -local learning_rate = 1.0 ---local learning_rate = 0.0032 / approx_cossim(66573) ---local learning_rate = 0.0056 / approx_cossim(66573) -local weight_decay = 0.00032 --0.001 --0.0023 --- -local cap_time = 200 --400 -local timer_loser = 1/2 -local decrement_reward = false -- bad idea, encourages mario to kill himself --- -local enable_overlay = playable_mode -local enable_network = not playable_mode - -local input_size = 60 + 4 -- TODO: let the script figure this out for us. -local tile_count = 17 * 13 - -local ok_routines = { - [0x4] = true, -- sliding down flagpole - [0x5] = true, -- end of level auto-walk - [0x7] = true, -- start of level auto-walk - [0x8] = true, -- normal (in control) - [0x9] = true, -- acquiring mushroom - [0xA] = true, -- losing big mario - [0xB] = true, -- uhh - [0xC] = true, -- acquiring fireflower -} - -local bad_states = { - power = true, - waiting_demo = true, - playing_demo = true, - unknown = true, - lose = true, -} - -local jp_lut = { - { -- none - up = false, down = false, left = false, right = false, - select = false, start = false, B = false, A = false, - }, { -- A - up = false, down = false, left = false, right = false, - select = false, start = false, B = false, A = true, - }, { -- L - up = false, down = false, left = true, right = false, - select = false, start = false, B = false, A = false, - }, { -- R - up = false, down = false, left = false, right = true, - select = false, start = false, B = false, A = false, - }, { -- L + B - up = false, down = false, left = true, right = false, - select = false, start = false, B = true, A = false, - }, { -- R + B - up = false, down = false, left = false, right = true, - select = false, start = false, B = true, A = false, - }, { -- L + A - up = false, down = false, left = true, right = false, - select = false, start = false, B = false, A = true, - }, { -- R + A - up = false, down = false, left = false, right = true, - select = false, start = false, B = false, A = true, - }, { -- L + A + B - up = false, down = false, left = true, right = false, - select = false, start = false, B = true, A = true, - }, { -- R + A + B - up = false, down = false, left = false, right = true, - select = false, start = false, B = true, A = true, - }, { -- D - up = false, down = true, left = false, right = false, - select = false, start = false, B = false, A = false, - }, { -- D + A - up = false, down = true, left = false, right = false, - select = false, start = false, B = false, A = true, - }, { -- U - up = true, down = false, left = false, right = false, - select = false, start = false, B = false, A = false, - }, -} +local cfg = require("config") +local gcfg = require("gameconfig") -- state. @@ -302,17 +202,21 @@ local network local nn_x, nn_tx, nn_ty, nn_y, nn_z local function make_network(input_size) nn_x = nn.Input({input_size}) - nn_tx = nn.Input({tile_count}) + nn_tx = nn.Input({gcfg.tile_count}) nn_ty = nn_tx:feed(nn.Embed(256, 2)) nn_y = nn.Merge() nn_x:feed(nn_y) nn_ty:feed(nn_y) nn_y = nn_y:feed(nn.Dense(128)) - nn_y = nn_y:feed(nn.Gelu()) + if cfg.deterministic then + nn_y = nn_y:feed(nn.Relu()) + else + nn_y = nn_y:feed(nn.Gelu()) + end nn_z = nn_y - nn_z = nn_z:feed(nn.Dense(#jp_lut)) + nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut)) nn_z = nn_z:feed(nn.Softmax()) return nn.Model({nn_x, nn_tx}, {nn_z}) end @@ -322,41 +226,6 @@ end -- disassembly used for reference: -- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM -local rotation_offsets = { -- FIXME: not all of these are pixel-perfect. - 0, -40, -- 0x00 - 6, -38, - 15, -37, - 22, -32, - 28, -28, - 32, -22, - 37, -14, - 39, -6, - 40, 0, -- 0x08 - 38, 7, - 37, 15, - 33, 23, - 27, 29, - 22, 33, - 14, 37, - 6, 39, - 0, 41, -- 0x10 - -7, 40, - -16, 38, - -22, 34, - -28, 28, - -34, 23, - -38, 16, - -40, 8, - -40, -0, -- 0x18 - -40, -6, - -38, -14, - -34, -22, - -28, -28, - -22, -32, - -16, -36, - -8, -38, -} - local function get_timer() return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA) end @@ -386,7 +255,7 @@ local function mark_sprite(x, y, t) sprite_input[#sprite_input+1] = t end if t == 0 then return end - if enable_overlay then + if cfg.enable_overlay then gui.box(x-4, y-4, x+4, y+4) --gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000') gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F') @@ -397,7 +266,7 @@ end local function mark_tile(x, y, t) tile_input[#tile_input+1] = t if t == 0 then return end - if enable_overlay then + if cfg.enable_overlay then gui.box(x-8, y-8, x+8, y+8) gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000') end @@ -479,7 +348,7 @@ local function handle_enemies() -- TODO: handle long fire bars too local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i) gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F') - local x_off, y_off = rotation_offsets[rot*2+1], rotation_offsets[rot*2+2] + local x_off, y_off = gcfg.rotation_offsets[rot*2+1], gcfg.rotation_offsets[rot*2+2] x, y = x + x_off, y + y_off end if invisible then @@ -578,11 +447,11 @@ local function prepare_epoch() empty(trial_rewards) -- TODO: save memory. generate noise as needed by saving the seed -- (the os.time() as of here) and calling nn.normal() each trial. - for i = 1, epoch_trials do + for i = 1, cfg.epoch_trials do local noise = nn.zeros(#base_params) -- NOTE: change in implementation: deviation is multiplied here -- and ONLY here now. - for j = 1, #base_params do noise[j] = deviation * nn.normal() end + for j = 1, #base_params do noise[j] = cfg.deviation * nn.normal() end trial_noise[i] = noise end trial_i = -1 @@ -590,7 +459,7 @@ end local function load_next_pair() trial_i = trial_i + 1 - if trial_i == 0 and not unperturbed_trial then + if trial_i == 0 and not cfg.unperturbed_trial then trial_i = 1 trial_neg = true end @@ -599,7 +468,7 @@ local function load_next_pair() if trial_i > 0 then if trial_neg then - if not defer_prints then print('trial', trial_i, 'positive') end + if not cfg.defer_prints then print('trial', trial_i, 'positive') end local noise = trial_noise[trial_i] for i, v in ipairs(base_params) do W[i] = v + noise[i] @@ -607,7 +476,7 @@ local function load_next_pair() else trial_i = trial_i - 1 - if not defer_prints then print('trial', trial_i, 'positive') end + if not cfg.defer_prints then print('trial', trial_i, 'positive') end local noise = trial_noise[trial_i] for i, v in ipairs(base_params) do W[i] = v - noise[i] @@ -616,17 +485,17 @@ local function load_next_pair() trial_neg = not trial_neg else - if not defer_prints then print("test trial") end + if not cfg.defer_prints then print("test trial") end end network:distribute(W) end local function load_next_trial() - if negate_trials then return load_next_pair() end + if cfg.negate_trials then return load_next_pair() end trial_i = trial_i + 1 local W = nn.copy(base_params) - if trial_i == 0 and not unperturbed_trial then + if trial_i == 0 and not cfg.unperturbed_trial then trial_i = 1 end if trial_i > 0 then @@ -682,7 +551,7 @@ local function learn_from_epoch() --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end - if unperturbed_trial then + if cfg.unperturbed_trial then local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0]) -- a rank of 1 means our gradient is uninformative. @@ -694,10 +563,10 @@ local function learn_from_epoch() -- new stuff local best_rewards - if negate_trials then + if cfg.negate_trials then -- select one (the best) reward of each pos/neg pair. best_rewards = {} - for i = 1, epoch_trials do + for i = 1, cfg.epoch_trials do local ind = (i - 1) * 2 + 1 local pos = trial_rewards[ind + 0] local neg = trial_rewards[ind + 1] @@ -712,7 +581,7 @@ local function learn_from_epoch() sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end) --print("indices:", indices) - for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end + for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end print("best trials:", indices) local top_rewards = {} @@ -731,27 +600,27 @@ local function learn_from_epoch() for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end -- NOTE: step no longer directly incorporates learning_rate. - for i = 1, epoch_trials do + for i = 1, cfg.epoch_trials do local ind = (i - 1) * 2 + 1 local pos = top_rewards[ind + 0] local neg = top_rewards[ind + 1] local reward = pos - neg local noise = trial_noise[i] for j, v in ipairs(noise) do - step[j] = step[j] + reward * v / epoch_top_trials + step[j] = step[j] + reward * v / cfg.epoch_top_trials end end local step_mean, step_dev = calc_mean_dev(step) if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end --print("step stddev:", step_dev) - print("full step stddev:", learning_rate * step_dev) + print("full step stddev:", cfg.learning_rate * step_dev) for i, v in ipairs(base_params) do - base_params[i] = v + learning_rate * step[i] - weight_decay * v + base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v end - if enable_network then + if cfg.enable_network then network:distribute(base_params) network:save() else @@ -782,10 +651,10 @@ local function do_reset() -- be a little more descriptive. if state == 'dead' and get_timer() == 0 then state = 'timeup' end - if trial_i >= 0 and defer_prints then + if trial_i >= 0 and cfg.defer_prints then if trial_i == 0 then print('test trial reward:', reward, "("..state..")") - elseif negate_trials then + elseif cfg.negate_trials then --local dir = trial_neg and "negative" or "positive" --print('trial', trial_i, dir, 'reward:', reward, "("..state..")") @@ -804,14 +673,14 @@ local function do_reset() end if trial_i >= 0 then - if trial_i == 0 or not negate_trials then + if trial_i == 0 or not cfg.negate_trials then trial_rewards[trial_i] = reward else trial_rewards[#trial_rewards + 1] = reward end end - if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then + if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then if epoch_i > 0 then learn_from_epoch() end epoch_i = epoch_i + 1 prepare_epoch() @@ -825,19 +694,19 @@ local function do_reset() score_old = get_score() -- set number of lives. (mario gets n+1 chances) - W(0x75A, starting_lives) + W(0x75A, cfg.starting_lives) - if start_big then + if cfg.start_big then -- make mario "super". W(0x754, 0) W(0x756, 1) end - --max_time = min(log(epoch_i) * 10 + 100, cap_time) - --max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time) - max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time) + --max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time) + --max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time) + max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time) max_time = ceil(max_time) - --max_time = cap_time + --max_time = cfg.cap_time if once then savestate.load(startsave) @@ -856,12 +725,12 @@ local function do_reset() end local function init() - network = make_network(input_size) + network = make_network(gcfg.input_size) network:reset() network:print() print("parameters:", network.n_param) - if init_zeros then + if cfg.init_zeros then local W = network:collect() for i, w in ipairs(W) do W[i] = 0 end network:distribute(W) @@ -889,11 +758,11 @@ local function doit(dummy) -- TODO: more robust. doesn't detect moonwalking against a wall. -- well, that shouldn't happen anymore now that i've disabled left+right. local timer = get_timer() - if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then + if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then timer = timer - 1 end timer = clamp(timer, 0, max_time) - if enable_network then + if cfg.enable_network then set_timer(timer) end @@ -924,7 +793,7 @@ local function doit(dummy) insert(extra_input, vx) insert(extra_input, vy) - if time_inputs then + if cfg.time_inputs then for i=2,5 do insert(extra_input, band(total_frames, lshift(1, i))) end @@ -946,14 +815,14 @@ local function doit(dummy) local powerup_delta = powerup_old - powerup -- 2 is fire mario. local status_delta = clamp(status - status_old, -1, 1) - local flagpole_bonus = R(0xE) == 4 and frameskip or 0 + local flagpole_bonus = R(0xE) == 4 and cfg.frameskip or 0 --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus local score_delta = get_score() - score_old if score_delta < 0 then score_delta = 0 end local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus screen_scroll_delta = 0 - if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end + if cfg.decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end if not ingame_paused then reward = reward + reward_delta end @@ -981,22 +850,22 @@ local function doit(dummy) local X = {} for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(extra_input) do insert(X, v / 256) end - nn.reshape(X, 1, input_size) - nn.reshape(tile_input, 1, tile_count) + nn.reshape(X, 1, gcfg.input_size) + nn.reshape(tile_input, 1, gcfg.tile_count) - if enable_network and get_state() == 'playing' or ingame_paused then - total_frames = total_frames + frameskip + if cfg.enable_network and get_state() == 'playing' or ingame_paused then + total_frames = total_frames + cfg.frameskip local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) - local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) - if det_epsilon and random() < eps then - local i = floor(random() * #jp_lut) + 1 - jp = nn.copy(jp_lut[i], jp) + local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames) + if cfg.det_epsilon and random() < eps then + local i = floor(random() * #gcfg.jp_lut) + 1 + jp = nn.copy(gcfg.jp_lut[i], jp) else - local choose = deterministic and argmax or softchoice + local choose = cfg.deterministic and argmax or softchoice local ind = choose(unpack(outputs[nn_z])) - jp = nn.copy(jp_lut[ind], jp) + jp = nn.copy(gcfg.jp_lut[ind], jp) end if force_start then @@ -1026,7 +895,7 @@ init() while true do gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F') - while bad_states[get_state()] do + while gcfg.bad_states[get_state()] do -- mash the start button until we have control. joypad_mash('start') reset = true @@ -1040,7 +909,7 @@ while true do if reset then do_reset() end - if not enable_network then + if not cfg.enable_network then -- infinite time cheat. super handy for testing. if R(0xE) == 8 then set_timer(667) @@ -1056,7 +925,7 @@ while true do -- FIXME: if the game lags then we might miss our frame to change inputs! -- don't rely on emu.framecount. - local doot = jp == nil or emu.framecount() % frameskip == 0 + local doot = jp == nil or emu.framecount() % cfg.frameskip == 0 doit(not doot) -- jp might still be nil if we're not ingame or we're not playing.