tweaks and fixes

This commit is contained in:
Connor Olding 2017-09-08 10:27:10 +00:00
parent 5a8c0f6140
commit 88dcd203a1
2 changed files with 8 additions and 10 deletions

View file

@ -1,21 +1,19 @@
-- hacks for FCEUX being dumb.
local _error = error
local _assert = assert
local function error_(msg, level)
error = function(msg, level)
if level == nil then level = 1 end
print()
print(debug.traceback(msg, 1 + level):gsub("\n", "\r\n"))
_error(msg, level)
end
local function assert_(cond, msg)
assert = function(cond, msg)
if cond then return cond end
msg = msg or "nondescript"
print()
print(debug.traceback(msg, 2):gsub("\n", "\r\n"))
_error("assertion failed!")
end
rawset(_G, 'error', error_)
rawset(_G, 'assert', assert_)
-- be strict about globals.
local mt = getmetatable(_G)
@ -34,10 +32,9 @@ local frameskip = 4
-- true epsilon-greedy has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling.
local det_epsilon = true -- take random actions with probability eps.
-- using parameters from DQN
local eps_start = 1.0 * frameskip / 64
local eps_stop = 0.1 * eps_start
local eps_frames = 1000000
local eps_frames = 2000000
local learn_start_select = false
--
local epoch_trials = 40
@ -714,7 +711,7 @@ local function doit(dummy)
local tf0 = total_frames % 1000
local tf1 = (total_frames % 1000000 - tf0) / 1000
local tf2 = (total_frames - tf0 - tf1) / 10000000
local tf2 = (total_frames - tf0 - tf1) / 1000000
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
screen_scroll_delta = screen_scroll_delta + R(0x775)
@ -723,8 +720,6 @@ local function doit(dummy)
return
end
total_frames = total_frames + frameskip
empty(sprite_input)
empty(tile_input)
empty(extra_input)
@ -795,6 +790,8 @@ local function doit(dummy)
nn.reshape(tile_input, 1, tile_count)
if enable_network and get_state() == 'playing' or ingame_paused then
total_frames = total_frames + frameskip
local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})

3
nn.lua
View file

@ -474,6 +474,7 @@ end
function Relu:forward(X)
local bs = checkshape(X, self.shape_in)
if bs ~= self.bs then self:reset_cache(bs) end
local Y = self.cache
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
@ -676,8 +677,8 @@ function Embed:forward(X)
local xi = x * self.dim
for j = 1, self.dim do
Y[yi+j] = self.lut[xi + j]
yi = yi + 1
end
yi = yi + self.dim
end
checkshape(Y, self.shape_out)