tweaks and fixes
This commit is contained in:
parent
5a8c0f6140
commit
88dcd203a1
2 changed files with 8 additions and 10 deletions
15
main.lua
15
main.lua
|
@ -1,21 +1,19 @@
|
|||
-- hacks for FCEUX being dumb.
|
||||
local _error = error
|
||||
local _assert = assert
|
||||
local function error_(msg, level)
|
||||
error = function(msg, level)
|
||||
if level == nil then level = 1 end
|
||||
print()
|
||||
print(debug.traceback(msg, 1 + level):gsub("\n", "\r\n"))
|
||||
_error(msg, level)
|
||||
end
|
||||
local function assert_(cond, msg)
|
||||
assert = function(cond, msg)
|
||||
if cond then return cond end
|
||||
msg = msg or "nondescript"
|
||||
print()
|
||||
print(debug.traceback(msg, 2):gsub("\n", "\r\n"))
|
||||
_error("assertion failed!")
|
||||
end
|
||||
rawset(_G, 'error', error_)
|
||||
rawset(_G, 'assert', assert_)
|
||||
|
||||
-- be strict about globals.
|
||||
local mt = getmetatable(_G)
|
||||
|
@ -34,10 +32,9 @@ local frameskip = 4
|
|||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||
local det_epsilon = true -- take random actions with probability eps.
|
||||
-- using parameters from DQN
|
||||
local eps_start = 1.0 * frameskip / 64
|
||||
local eps_stop = 0.1 * eps_start
|
||||
local eps_frames = 1000000
|
||||
local eps_frames = 2000000
|
||||
local learn_start_select = false
|
||||
--
|
||||
local epoch_trials = 40
|
||||
|
@ -714,7 +711,7 @@ local function doit(dummy)
|
|||
|
||||
local tf0 = total_frames % 1000
|
||||
local tf1 = (total_frames % 1000000 - tf0) / 1000
|
||||
local tf2 = (total_frames - tf0 - tf1) / 10000000
|
||||
local tf2 = (total_frames - tf0 - tf1) / 1000000
|
||||
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
|
||||
|
||||
screen_scroll_delta = screen_scroll_delta + R(0x775)
|
||||
|
@ -723,8 +720,6 @@ local function doit(dummy)
|
|||
return
|
||||
end
|
||||
|
||||
total_frames = total_frames + frameskip
|
||||
|
||||
empty(sprite_input)
|
||||
empty(tile_input)
|
||||
empty(extra_input)
|
||||
|
@ -795,6 +790,8 @@ local function doit(dummy)
|
|||
nn.reshape(tile_input, 1, tile_count)
|
||||
|
||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + frameskip
|
||||
|
||||
local choose = deterministic and argmax2 or rchoice2
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||
|
|
3
nn.lua
3
nn.lua
|
@ -474,6 +474,7 @@ end
|
|||
|
||||
function Relu:forward(X)
|
||||
local bs = checkshape(X, self.shape_in)
|
||||
if bs ~= self.bs then self:reset_cache(bs) end
|
||||
local Y = self.cache
|
||||
|
||||
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
|
||||
|
@ -676,8 +677,8 @@ function Embed:forward(X)
|
|||
local xi = x * self.dim
|
||||
for j = 1, self.dim do
|
||||
Y[yi+j] = self.lut[xi + j]
|
||||
yi = yi + 1
|
||||
end
|
||||
yi = yi + self.dim
|
||||
end
|
||||
|
||||
checkshape(Y, self.shape_out)
|
||||
|
|
Loading…
Reference in a new issue