diff --git a/main.lua b/main.lua
index 5d8ecde..15c1836 100644
--- a/main.lua
+++ b/main.lua
@@ -24,15 +24,27 @@ end
 --randomseed(11)
 
-local enable_overlay = false
-local enable_network = true
+local playable_mode = false
 --
-local epoch_trials = 24
-local learning_rate = 3.2e-3
-local deviation = 1e-1 / epoch_trials
+local deterministic = false -- use argmax on outputs instead of random sampling.
+local det_epsilon = true -- take random actions with probability eps.
+local eps_start = 0.50
+local eps_stop = 0.05
+local eps_frames = 60*60*60 -- 216000; an hour of frames at 60 fps.
+local consider_past_rewards = false
+local learn_start_select = false
+--
+local epoch_trials = 40 -- 24
+local learning_rate = 1e-3
+local deviation = 1e-2 -- 4e-3
 --
-local timer_loser = 1/3
 local cap_time = 400
+local timer_loser = 1/3
+--
+local enable_overlay = playable_mode
+local enable_network = not playable_mode
+
+local input_size = 281 -- TODO: let the script figure this out for us.
 
 local epoch_i = 0
 local base_params
@@ -41,6 +53,9 @@ local trial_noise = {}
 local trial_rewards = {}
 local trials_remaining = 0
 
+local trial_frames = 0
+local total_frames = 0
+
 local force_start = false
 local force_start_old = false
 
@@ -51,12 +66,15 @@ local max_time
 
 local sprite_input = {}
 local tile_input = {}
+local extra_input = {}
 
 local reward
+local all_rewards = {}
 
 local powerup_old
 local status_old
 local coins_old
+local score_old
 
 local once = false
 local reset = true
@@ -125,6 +143,8 @@ end
 
 local function clamp(x, l, u) return min(max(x, l), u) end
 
+local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
+
 local function argmax(...)
     local max_i = 0
     local max_v = -999999999
@@ -142,6 +162,14 @@ local function argmax2(t)
     return t[1] > t[2]
 end
 
+local function rchoice2(t)
+    return t[1] > random()
+end
+
+local function rbool()
+    return 0.5 >= random()
+end
+
 local function empty(t)
     for k, _ in pairs(t) do t[k] = nil end
     return t
@@ -171,6 +199,15 @@ local function normalize(x, out)
     return out
 end
 
+local function normalize_wrt(x, s, out)
+    out = out or x
+    local mean, dev = calc_mean_dev(s)
+    if dev <= 0 then dev = 1 end
+    local devs = sqrt(dev)
+    for i, v in ipairs(x) do out[i] = (v - mean) / devs end
+    return out
+end
+
 -- game-agnostic stuff (i.e. the network itself)
 
 package.loaded['nn'] = nil -- DEBUG
@@ -188,10 +225,12 @@ local function make_network(input_size, buttons)
         nn_y = nn_y:feed(nn.Dense(input_size))
         nn_y = nn_y:feed(nn.Gelu())
     else
-        nn_y = nn_y:feed(nn.Dense(114))
-        nn_y = nn_y:feed(nn.Relu())
-        nn_y = nn_y:feed(nn.Dense(57))
-        nn_y = nn_y:feed(nn.Relu())
+        nn_y = nn_y:feed(nn.Dense(128))
+        nn_y = nn_y:feed(nn.Gelu())
+        nn_y = nn_y:feed(nn.Dense(64))
+        nn_y = nn_y:feed(nn.Gelu())
+        nn_y = nn_y:feed(nn.Dense(48))
+        nn_y = nn_y:feed(nn.Gelu())
     end
     for i = 1, buttons do
         nn_z[i] = nn_y
@@ -247,6 +286,14 @@ local function get_timer()
     return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
 end
 
+local function get_score()
+    return R(0x7DE) * 10000 +
+           R(0x7DF) * 1000 +
+           R(0x7E0) * 100 +
+           R(0x7E1) * 10 +
+           R(0x7E2)
+end
+
 local function set_timer(time)
     W(0x7F8, floor(time / 100))
     W(0x7F9, floor((time / 10) % 10))
@@ -460,7 +507,16 @@ end
 local function learn_from_epoch()
     print()
     print('rewards:', trial_rewards)
-    normalize(trial_rewards)
+
+    for _, v in ipairs(trial_rewards) do
+        insert(all_rewards, v)
+    end
+
+    if consider_past_rewards then
+        normalize_wrt(trial_rewards, all_rewards)
+    else
+        normalize(trial_rewards)
+    end
     --print('normalized:', trial_rewards)
 
     local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
@@ -501,9 +557,12 @@ local function learn_from_epoch()
         base_params[i] = v + step[i]
     end
 
-    network:distribute(base_params)
-
-    network:save()
+    if enable_network then
+        network:distribute(base_params)
+        network:save()
+    else
+        print("note: not updating weights in playable mode.")
+    end
 
     print()
 end
@@ -534,7 +593,7 @@ local function load_next_trial()
 end
 
 local function do_reset()
-    print("resetting in state:", get_state())
+    print("resetting in state: "..get_state()..". reward:", reward)
     if trial_i > 0 then
         trial_rewards[trial_i] = reward
     end
@@ -550,13 +609,15 @@ local function do_reset()
     powerup_old = R(0x754)
     status_old = R(0x756)
     coins_old = R(0x7ED) * 10 + R(0x7EE)
+    score_old = get_score()
 
     -- set lives to 0. you only got one shot!
     -- unless you get a 1-up, in which case, please continue!
     W(0x75A, 0)
 
     --max_time = min(log(epoch_i) * 10 + 100, cap_time)
-    max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
+    max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
+    max_time = ceil(max_time)
 
     if once then
         savestate.load(startsave)
@@ -575,7 +636,7 @@ end
 
 local function init()
-    network = make_network(279, 8)
+    network = make_network(input_size, learn_start_select and 8 or 6)
     network:reset()
 
     print("parameters:", network.n_param)
 
@@ -583,13 +644,14 @@ local function init()
     emu.unpause()
     emu.speedmode("turbo")
 
-    network:load()
     local res, err = pcall(network.load, network)
     if res == false then print(err) end
 end
 
 init()
 
+local dummy_softmax_values = {0, 0} -- stands in for start/select when they aren't learned.
+
 while true do
     gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
 
@@ -638,6 +700,7 @@ while true do
     empty(sprite_input)
     empty(tile_input)
+    empty(extra_input)
 
     -- player
     -- TODO: check if mario is playable.
 
@@ -646,6 +709,10 @@ while true do
     local status = R(0x756)
     mark_sprite(x + 8, y + 24, -powerup - 1)
+    local vx, vy = S(0x57), S(0x9F)
+    insert(extra_input, vx)
+    insert(extra_input, vy)
+
     handle_enemies()
     handle_fireballs()
 
     -- blocks being hit. not interactable; we don't care!
@@ -666,7 +733,10 @@ while true do
     local status_delta = clamp(status - status_old, -1, 1)
     local screen_scroll_delta = R(0x775)
     local flagpole_bonus = R(0xE) == 4 and 1 or 0
-    local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
+    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
+    local score_delta = get_score() - score_old
+    if score_delta < 0 then score_delta = 0 end
+    local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
 
     -- TODO: add ingame score to reward.
 
@@ -674,8 +744,9 @@
     --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
     --gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
-    gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
-    gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
+    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
+    --gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
+    gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
 
     if get_state() == 'dead' and state_old ~= 'dead' then
         --print("dead. lives remaining:", R(0x75A, 0))
@@ -698,7 +769,9 @@
         timer = timer - 1
     end
     timer = clamp(timer, 0, max_time)
-    set_timer(timer)
+    if enable_network then
+        set_timer(timer)
+    end
 
     -- if we've run out of time while the game is paused...
     -- that's cheating! unpause.
@@ -707,10 +780,19 @@
     local X = {} -- TODO: cache.
     for i, v in ipairs(sprite_input) do insert(X, v / 256) end
     for i, v in ipairs(tile_input) do insert(X, v / 256) end
-    --error(#X)
+    for i, v in ipairs(extra_input) do insert(X, v / 256) end
+    if #X ~= input_size then error("input size should be: "..tostring(#X)) end
 
     if enable_network and get_state() == 'playing' or ingame_paused then
+        local choose = deterministic and argmax2 or rchoice2
+
         local outputs = network:forward(X)
+
+        -- TODO: predict the *rewards* of all possible actions?
+        -- that's how DQN seems to work, anyway.
+        -- ah, but A3C just returns action probabilities,
+        -- plus the critic's value estimate.
+
         local softmaxed = {
             outputs[nn_z[1]],
             outputs[nn_z[2]],
@@ -718,19 +800,29 @@
             outputs[nn_z[4]],
             outputs[nn_z[5]],
             outputs[nn_z[6]],
-            outputs[nn_z[7]],
-            outputs[nn_z[8]],
+            learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
+            learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
         }
+
         local jp = {
-            up = argmax2(softmaxed[1]),
-            down = argmax2(softmaxed[2]),
-            left = argmax2(softmaxed[3]),
-            right = argmax2(softmaxed[4]),
-            A = argmax2(softmaxed[5]),
-            B = argmax2(softmaxed[6]),
-            start = argmax2(softmaxed[7]),
-            select = argmax2(softmaxed[8]),
+            up = choose(softmaxed[1]),
+            down = choose(softmaxed[2]),
+            left = choose(softmaxed[3]),
+            right = choose(softmaxed[4]),
+            A = choose(softmaxed[5]),
+            B = choose(softmaxed[6]),
+            start = choose(softmaxed[7]),
+            select = choose(softmaxed[8]),
         }
+
+        if det_epsilon then
+            local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
+            for k, _ in pairs(jp) do
+                local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
+                if random() < eps and ss_ok then jp[k] = rbool() end
+            end
+        end
+
         if force_start then
             jp = {
                 up = false,
@@ -743,6 +835,7 @@
                 select = false,
             }
         end
+
         joypad.write(1, jp)
     end
 
@@ -751,5 +844,6 @@
     status_old = status
     force_start_old = force_start
     state_old = get_state()
+    score_old = get_score()
     advance()
 end
diff --git a/nn.lua b/nn.lua
index 2943f75..996ef98 100644
--- a/nn.lua
+++ b/nn.lua
@@ -173,6 +173,10 @@ function Layer:forward_deterministic(...)
     return self:forward(...)
 end
 
+function Layer:backward()
+    error("Unimplemented.")
+end
+
 function Layer:_new_weights(init)
     local w = Weights(init)
     insert(self.weights, w)
@@ -227,6 +231,11 @@ function Input:forward(X)
     return X
 end
 
+function Input:backward(dY)
+    assert(#dY == self.size_out)
+    return zeros(#dY)
+end
+
 function Relu:init()
     Layer.init(self, "Relu")
 end
@@ -242,6 +251,18 @@ function Relu:forward(X)
     return Y
 end
 
+function Relu:backward(dY)
+    assert(#dY == self.size_out)
+    self.dcache = self.dcache or zeros(self.size_in)
+    local Y = self.cache
+    local dX = self.dcache
+
+    for i = 1, #dY do dX[i] = Y[i] > 0 and dY[i] or 0 end -- pass gradient only through active units.
+
+    assert(#dX == self.size_in)
+    return dX
+end
+
 function Gelu:init()
     Layer.init(self, "Gelu")
 end
@@ -383,8 +404,12 @@ function Model:distribute(W)
     end
 end
 
+function Model:default_filename()
+    return ('network%07i.txt'):format(self.n_param)
+end
+
 function Model:save(fn)
-    local fn = fn or 'network.txt'
+    local fn = fn or self:default_filename()
     local f = open(fn, 'w')
     if f == nil then error("Failed to save network to file "..fn) end
     local W = self:collect()
@@ -396,7 +421,7 @@ end
 
 function Model:load(fn)
-    local fn = fn or 'network.txt'
+    local fn = fn or self:default_filename()
     local f = open(fn, 'r')
     if f == nil then error("Failed to load network from file "..fn
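
Two notes on the changes above follow. First, the exploration schedule: with det_epsilon enabled, every eligible button is overridden by a fair coin flip with probability eps, and eps anneals linearly from eps_start (0.50) down to eps_stop (0.05) over eps_frames frames of play, after which the clamp inside lerp pins it at eps_stop, so exploration never fully stops. A minimal standalone sketch of that logic, assuming only stock Lua (the perturb helper is a hypothetical name; main.lua does this inline in the main loop):

    local random, min, max = math.random, math.min, math.max

    local eps_start  = 0.50
    local eps_stop   = 0.05
    local eps_frames = 60*60*60 -- an hour of frames at 60 fps.

    local function clamp(x, l, u) return min(max(x, l), u) end
    local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
    local function rbool() return 0.5 >= random() end

    -- override each eligible button with probability eps; eps anneals
    -- linearly from eps_start to eps_stop as total_frames reaches eps_frames.
    local function perturb(jp, total_frames, learn_start_select)
        local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
        for k, _ in pairs(jp) do
            local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
            if random() < eps and ss_ok then jp[k] = rbool() end
        end
        return jp
    end

    -- halfway through the schedule, eps = lerp(0.50, 0.05, 0.5) = 0.275.
    local jp = perturb({A = true, B = false, left = false, right = true},
                       30*60*60, false)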
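Second, the ReLU gradient: Relu:backward passes each incoming gradient through only where the unit was active, i.e. where the cached output Y[i] = max(0, X[i]) is strictly positive (testing Y[i] >= 0 instead would always succeed, since ReLU outputs are never negative, and would pass every gradient through). A finite-difference check of that rule, as a self-contained sketch (relu_forward and relu_backward are hypothetical stand-ins for the layer's methods):

    local function relu_forward(X)
        local Y = {}
        for i, v in ipairs(X) do Y[i] = v > 0 and v or 0 end
        return Y
    end

    -- dX[i] = dY[i] wherever the unit was active (Y[i] > 0), else 0.
    local function relu_backward(Y, dY)
        local dX = {}
        for i = 1, #dY do dX[i] = Y[i] > 0 and dY[i] or 0 end
        return dX
    end

    local X, dY, eps = {-1.0, 0.5, 2.0}, {1.0, 1.0, 1.0}, 1e-6
    local dX = relu_backward(relu_forward(X), dY)
    for i = 1, #X do
        -- numeric derivative of the i-th output w.r.t. the i-th input.
        local Xp = {X[1], X[2], X[3]}
        Xp[i] = Xp[i] + eps
        local numeric = (relu_forward(Xp)[i] - relu_forward(X)[i]) / eps
        assert(math.abs(numeric * dY[i] - dX[i]) < 1e-3)
    end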