tweaks and fixes
This commit is contained in:
parent
5a8c0f6140
commit
88dcd203a1
2 changed files with 8 additions and 10 deletions
15
main.lua
15
main.lua
|
@ -1,21 +1,19 @@
|
||||||
-- hacks for FCEUX being dumb.
|
-- hacks for FCEUX being dumb.
|
||||||
local _error = error
|
local _error = error
|
||||||
local _assert = assert
|
local _assert = assert
|
||||||
local function error_(msg, level)
|
error = function(msg, level)
|
||||||
if level == nil then level = 1 end
|
if level == nil then level = 1 end
|
||||||
print()
|
print()
|
||||||
print(debug.traceback(msg, 1 + level):gsub("\n", "\r\n"))
|
print(debug.traceback(msg, 1 + level):gsub("\n", "\r\n"))
|
||||||
_error(msg, level)
|
_error(msg, level)
|
||||||
end
|
end
|
||||||
local function assert_(cond, msg)
|
assert = function(cond, msg)
|
||||||
if cond then return cond end
|
if cond then return cond end
|
||||||
msg = msg or "nondescript"
|
msg = msg or "nondescript"
|
||||||
print()
|
print()
|
||||||
print(debug.traceback(msg, 2):gsub("\n", "\r\n"))
|
print(debug.traceback(msg, 2):gsub("\n", "\r\n"))
|
||||||
_error("assertion failed!")
|
_error("assertion failed!")
|
||||||
end
|
end
|
||||||
rawset(_G, 'error', error_)
|
|
||||||
rawset(_G, 'assert', assert_)
|
|
||||||
|
|
||||||
-- be strict about globals.
|
-- be strict about globals.
|
||||||
local mt = getmetatable(_G)
|
local mt = getmetatable(_G)
|
||||||
|
@ -34,10 +32,9 @@ local frameskip = 4
|
||||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||||
local det_epsilon = true -- take random actions with probability eps.
|
local det_epsilon = true -- take random actions with probability eps.
|
||||||
-- using parameters from DQN
|
|
||||||
local eps_start = 1.0 * frameskip / 64
|
local eps_start = 1.0 * frameskip / 64
|
||||||
local eps_stop = 0.1 * eps_start
|
local eps_stop = 0.1 * eps_start
|
||||||
local eps_frames = 1000000
|
local eps_frames = 2000000
|
||||||
local learn_start_select = false
|
local learn_start_select = false
|
||||||
--
|
--
|
||||||
local epoch_trials = 40
|
local epoch_trials = 40
|
||||||
|
@ -714,7 +711,7 @@ local function doit(dummy)
|
||||||
|
|
||||||
local tf0 = total_frames % 1000
|
local tf0 = total_frames % 1000
|
||||||
local tf1 = (total_frames % 1000000 - tf0) / 1000
|
local tf1 = (total_frames % 1000000 - tf0) / 1000
|
||||||
local tf2 = (total_frames - tf0 - tf1) / 10000000
|
local tf2 = (total_frames - tf0 - tf1) / 1000000
|
||||||
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
|
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
|
||||||
|
|
||||||
screen_scroll_delta = screen_scroll_delta + R(0x775)
|
screen_scroll_delta = screen_scroll_delta + R(0x775)
|
||||||
|
@ -723,8 +720,6 @@ local function doit(dummy)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
total_frames = total_frames + frameskip
|
|
||||||
|
|
||||||
empty(sprite_input)
|
empty(sprite_input)
|
||||||
empty(tile_input)
|
empty(tile_input)
|
||||||
empty(extra_input)
|
empty(extra_input)
|
||||||
|
@ -795,6 +790,8 @@ local function doit(dummy)
|
||||||
nn.reshape(tile_input, 1, tile_count)
|
nn.reshape(tile_input, 1, tile_count)
|
||||||
|
|
||||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
|
total_frames = total_frames + frameskip
|
||||||
|
|
||||||
local choose = deterministic and argmax2 or rchoice2
|
local choose = deterministic and argmax2 or rchoice2
|
||||||
|
|
||||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||||
|
|
3
nn.lua
3
nn.lua
|
@ -474,6 +474,7 @@ end
|
||||||
|
|
||||||
function Relu:forward(X)
|
function Relu:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
|
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
|
||||||
|
@ -676,8 +677,8 @@ function Embed:forward(X)
|
||||||
local xi = x * self.dim
|
local xi = x * self.dim
|
||||||
for j = 1, self.dim do
|
for j = 1, self.dim do
|
||||||
Y[yi+j] = self.lut[xi + j]
|
Y[yi+j] = self.lut[xi + j]
|
||||||
yi = yi + 1
|
|
||||||
end
|
end
|
||||||
|
yi = yi + self.dim
|
||||||
end
|
end
|
||||||
|
|
||||||
checkshape(Y, self.shape_out)
|
checkshape(Y, self.shape_out)
|
||||||
|
|
Loading…
Reference in a new issue