From 9017af0d13f6c752e263f5e4016c48f8aef4d216 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 29 Jun 2017 09:50:33 +0000 Subject: [PATCH] various work --- main.lua | 112 +++++++++++++++++++++++++++++++++---------------------- nn.lua | 20 ++++++++++ 2 files changed, 88 insertions(+), 44 deletions(-) diff --git a/main.lua b/main.lua index 1c4080f..5d8ecde 100644 --- a/main.lua +++ b/main.lua @@ -22,13 +22,19 @@ end -- configuration and globals. -math.randomseed(10) +--randomseed(11) -local learning_rate = 1e-2 -local deviation = 2e-2 +local enable_overlay = false +local enable_network = true +-- +local epoch_trials = 24 +local learning_rate = 3.2e-3 +local deviation = 1e-1 / epoch_trials +-- +local timer_loser = 1/3 +local cap_time = 400 local epoch_i = 0 -local epoch_trials = 12 local base_params local trial_i = 0 local trial_noise = {} @@ -38,17 +44,16 @@ local trials_remaining = 0 local force_start = false local force_start_old = false -local enable_overlay = false -local enable_network = true - local startsave = savestate.create(1) local poketime = false +local max_time local sprite_input = {} local tile_input = {} local reward + local powerup_old local status_old local coins_old @@ -89,8 +94,10 @@ local ceil = math.ceil local min = math.min local max = math.max local exp = math.exp +local log = math.log local sqrt = math.sqrt local random = math.random +local randomseed = math.randomseed local insert = table.insert local remove = table.remove local unpack = table.unpack or unpack @@ -177,12 +184,15 @@ local function make_network(input_size, buttons) nn_x = nn.Input(input_size) nn_y = nn_x nn_z = {} - nn_y = nn_y:feed(nn.Dense(input_size)) - nn_y = nn_y:feed(nn.Relu()) - --nn_y = nn_y:feed(nn.Dense(floor(input_size / 16))) - --nn_y = nn_y:feed(nn.Relu()) - --nn_y = nn_y:feed(nn.Dense(floor(input_size / 16))) - --nn_y = nn_y:feed(nn.Relu()) + if false then + nn_y = nn_y:feed(nn.Dense(input_size)) + nn_y = nn_y:feed(nn.Gelu()) + else + nn_y = nn_y:feed(nn.Dense(114)) + nn_y = nn_y:feed(nn.Relu()) + nn_y = nn_y:feed(nn.Dense(57)) + nn_y = nn_y:feed(nn.Relu()) + end for i = 1, buttons do nn_z[i] = nn_y nn_z[i] = nn_z[i]:feed(nn.Dense(2)) @@ -451,7 +461,7 @@ local function learn_from_epoch() print() print('rewards:', trial_rewards) normalize(trial_rewards) - print('normalized:', trial_rewards) + --print('normalized:', trial_rewards) local reward_mean, reward_dev = calc_mean_dev(trial_rewards) @@ -465,7 +475,7 @@ local function learn_from_epoch() end local magnitude = learning_rate / deviation - print('stepping with magnitude', magnitude) + --print('stepping with magnitude', magnitude) -- throw the division from the averaging in there too. local altogether = magnitude / epoch_trials for i, v in ipairs(step) do @@ -473,19 +483,22 @@ local function learn_from_epoch() end local step_mean, step_dev = calc_mean_dev(step) - print("step mean:", step_mean) + if step_dev < 1e-8 then + -- we didn't get anywhere. step in a random direction. + print("stepping randomly.") + local noise = trial_noise[1] + local devsqrt = sqrt(deviation) + for i, v in ipairs(step) do + step[i] = devsqrt * noise[i] + end + + step_mean, step_dev = calc_mean_dev(step) + end + if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end print("step stddev:", step_dev) - if step_dev > 1e-8 then - for i, v in ipairs(base_params) do - base_params[i] = v + step[i] - end - else - -- we didn't get anywhere. step in a random direction. - local noise = trial_noise[1] - for i, v in ipairs(base_params) do - base_params[i] = v + magnitude * noise[i] - end + for i, v in ipairs(base_params) do + base_params[i] = v + step[i] end network:distribute(base_params) @@ -513,8 +526,9 @@ local function load_next_trial() print('loading trial', trial_i) local W = nn.copy(base_params) local noise = trial_noise[trial_i] + local devsqrt = sqrt(deviation) for i, v in ipairs(base_params) do - W[i] = v + deviation * noise[i] + W[i] = v + devsqrt * noise[i] end network:distribute(W) end @@ -525,19 +539,11 @@ local function do_reset() if trial_i > 0 then trial_rewards[trial_i] = reward end if epoch_i == 0 or trial_i == epoch_trials then - if epoch_i > 0 then learn_from_epoch() else network:reset() end + if epoch_i > 0 then learn_from_epoch() end epoch_i = epoch_i + 1 prepare_epoch() end - if once then - savestate.load(startsave) - print("end of trial reward:", reward) - else - savestate.save(startsave) - end - once = true - -- bit of a hack: if get_state() == 'loading' then advance() end reward = 0 @@ -549,9 +555,20 @@ local function do_reset() -- unless you get a 1-up, in which case, please continue! W(0x75A, 0) + --max_time = min(log(epoch_i) * 10 + 100, cap_time) + max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time) + + if once then + savestate.load(startsave) + --print("end of trial reward:", reward) + else + savestate.save(startsave) + end + once = true + emu.frameadvance() -- prevents emulator from quirking up. - print() + --print() load_next_trial() reset = false @@ -564,7 +581,11 @@ local function init() emu.poweron() emu.unpause() - emu.speedmode("normal") + emu.speedmode("turbo") + + network:load() + local res, err = pcall(network.load, network) + if res == false then print(err) end end init() @@ -647,6 +668,8 @@ while true do local flagpole_bonus = R(0xE) == 4 and 1 or 0 local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus + -- TODO: add ingame score to reward. + if not ingame_paused then reward = reward + reward_delta end --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F') @@ -655,7 +678,7 @@ while true do gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') if get_state() == 'dead' and state_old ~= 'dead' then - print("dead. lives remaining:", R(0x75A, 0)) + --print("dead. lives remaining:", R(0x75A, 0)) if R(0x75A, 0) == 0 then reset = true end end if get_state() == 'lose' then @@ -664,17 +687,18 @@ while true do end -- lose a point for every frame paused. - if ingame_paused then reward = reward - 1 end + --if ingame_paused then reward = reward - 1 end + if ingame_paused then reward = reward - 402; reset = true end -- every few frames mario stands still, forcibly decrease the timer. -- this includes having the game paused. -- TODO: more robust. doesn't detect moonwalking against a wall. local timer = get_timer() - local timer_loser = 1/5 - if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then - timer = clamp(timer - 1, 0, 400) - set_timer(timer) + if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then + timer = timer - 1 end + timer = clamp(timer, 0, max_time) + set_timer(timer) -- if we've run out of time while the game is paused... -- that's cheating! unpause. diff --git a/nn.lua b/nn.lua index f171067..2943f75 100644 --- a/nn.lua +++ b/nn.lua @@ -124,6 +124,7 @@ local Layer = Base:extend() local Model = Base:extend() local Input = Layer:extend() local Relu = Layer:extend() +local Gelu = Layer:extend() local Dense = Layer:extend() local Softmax = Layer:extend() @@ -241,6 +242,24 @@ function Relu:forward(X) return Y end +function Gelu:init() + Layer.init(self, "Gelu") +end + +function Gelu:forward(X) + assert(#X == self.size_in) + self.cache = self.cache or zeros(self.size_out) + local Y = self.cache + + -- NOTE: approximate form of GELU exploiting similarities to sigmoid curve. + for i = 1, #X do + Y[i] = X[i] / (1 + exp(-1.704 * X[i])) + end + + assert(#Y == self.size_out) + return Y +end + function Dense:init(dim) Layer.init(self, "Dense") assert(type(dim) == "number") @@ -413,6 +432,7 @@ return { Model = Model, Input = Input, Relu = Relu, + Gelu = Gelu, Dense = Dense, Softmax = Softmax, }