diff --git a/main.lua b/main.lua index cf194ee..a94cbf4 100644 --- a/main.lua +++ b/main.lua @@ -1,21 +1,19 @@ -- hacks for FCEUX being dumb. local _error = error local _assert = assert -local function error_(msg, level) +error = function(msg, level) if level == nil then level = 1 end print() print(debug.traceback(msg, 1 + level):gsub("\n", "\r\n")) _error(msg, level) end -local function assert_(cond, msg) +assert = function(cond, msg) if cond then return cond end msg = msg or "nondescript" print() print(debug.traceback(msg, 2):gsub("\n", "\r\n")) _error("assertion failed!") end -rawset(_G, 'error', error_) -rawset(_G, 'assert', assert_) -- be strict about globals. local mt = getmetatable(_G) @@ -34,10 +32,9 @@ local frameskip = 4 -- true greedy epsilon has both deterministic and det_epsilon set. local deterministic = true -- use argmax on outputs instead of random sampling. local det_epsilon = true -- take random actions with probability eps. --- using parameters from DQN local eps_start = 1.0 * frameskip / 64 local eps_stop = 0.1 * eps_start -local eps_frames = 1000000 +local eps_frames = 2000000 local learn_start_select = false -- local epoch_trials = 40 @@ -714,7 +711,7 @@ local function doit(dummy) local tf0 = total_frames % 1000 local tf1 = (total_frames % 1000000 - tf0) / 1000 - local tf2 = (total_frames - tf0 - tf1) / 10000000 + local tf2 = (total_frames - tf0 - tf1) / 1000000 gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F') screen_scroll_delta = screen_scroll_delta + R(0x775) @@ -723,8 +720,6 @@ local function doit(dummy) return end - total_frames = total_frames + frameskip - empty(sprite_input) empty(tile_input) empty(extra_input) @@ -795,6 +790,8 @@ local function doit(dummy) nn.reshape(tile_input, 1, tile_count) if enable_network and get_state() == 'playing' or ingame_paused then + total_frames = total_frames + frameskip + local choose = deterministic and argmax2 or rchoice2 local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) diff --git a/nn.lua b/nn.lua index 72c9789..32ff0bd 100644 --- a/nn.lua +++ b/nn.lua @@ -474,6 +474,7 @@ end function Relu:forward(X) local bs = checkshape(X, self.shape_in) + if bs ~= self.bs then self:reset_cache(bs) end local Y = self.cache for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end @@ -676,8 +677,8 @@ function Embed:forward(X) local xi = x * self.dim for j = 1, self.dim do Y[yi+j] = self.lut[xi + j] - yi = yi + 1 end + yi = yi + self.dim end checkshape(Y, self.shape_out)