This commit is contained in:
commit
29f3c278ea
2 changed files with 938 additions and 0 deletions
561
main.lua
Normal file
561
main.lua
Normal file
|
@ -0,0 +1,561 @@
|
|||
-- be strict about globals.
|
||||
|
||||
local mt = getmetatable(_G)
|
||||
if mt == nil then
|
||||
mt = {}
|
||||
setmetatable(_G, mt)
|
||||
end
|
||||
|
||||
function mt.__newindex(t, n, v)
|
||||
error("cannot assign undeclared global '" .. tostring(n) .. "'", 2)
|
||||
end
|
||||
|
||||
function mt.__index(t, n)
|
||||
error("cannot use undeclared global '" .. tostring(n) .. "'", 2)
|
||||
end
|
||||
|
||||
local function globalize(t)
|
||||
for k, v in pairs(t) do
|
||||
rawset(_G, k, v)
|
||||
end
|
||||
end
|
||||
|
||||
-- localize some stuff.
|
||||
|
||||
local ipairs = ipairs
|
||||
local pairs = pairs
|
||||
local select = select
|
||||
local abs = math.abs
|
||||
local floor = math.floor
|
||||
local ceil = math.ceil
|
||||
local min = math.min
|
||||
local max = math.max
|
||||
local exp = math.exp
|
||||
local sqrt = math.sqrt
|
||||
local random = math.random
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
local unpack = table.unpack or unpack
|
||||
local R = memory.readbyteunsigned
|
||||
local S = memory.readbyte --signed
|
||||
local W = memory.writebyte
|
||||
|
||||
local band = bit.band
|
||||
local bor = bit.bor
|
||||
local bxor = bit.bxor
|
||||
local bnot = bit.bnot
|
||||
local lshift = bit.lshift
|
||||
local rshift = bit.rshift
|
||||
local arshift = bit.arshift
|
||||
local rol = bit.rol
|
||||
local ror = bit.ror
|
||||
|
||||
local function clamp(x, l, u) return min(max(x, l), u) end
|
||||
|
||||
local function argmax(...)
|
||||
local max_i = 0
|
||||
local max_v = -999999999
|
||||
for i=1, select("#", ...) do
|
||||
local v = select(i, ...)
|
||||
if v > max_v then
|
||||
max_i = i
|
||||
max_v = v
|
||||
end
|
||||
end
|
||||
return max_i
|
||||
end
|
||||
|
||||
local function argmax2(t)
|
||||
return t[1] > t[2]
|
||||
end
|
||||
|
||||
-- game-agnostic stuff (i.e. the network itself)
|
||||
|
||||
package.loaded['nn'] = nil -- DEBUG
|
||||
local nn = require("nn")
|
||||
|
||||
local network
|
||||
local nn_x
|
||||
local nn_y
|
||||
local nn_z
|
||||
local function make_network(input_size, buttons)
|
||||
nn_x = nn.Input(input_size)
|
||||
nn_y = nn_x
|
||||
nn_z = {}
|
||||
nn_y = nn_y:feed(nn.Dense(input_size))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
|
||||
--nn_y = nn_y:feed(nn.Relu())
|
||||
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
|
||||
--nn_y = nn_y:feed(nn.Relu())
|
||||
for i = 1, buttons do
|
||||
nn_z[i] = nn_y
|
||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
||||
nn_z[i] = nn_z[i]:feed(nn.Softmax())
|
||||
end
|
||||
|
||||
return nn.Model({nn_x}, nn_z)
|
||||
end
|
||||
|
||||
-- and here we go with the game stuff.
|
||||
|
||||
local enable_overlay = false
|
||||
local enable_network = true
|
||||
|
||||
--[[
|
||||
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||
--]]
|
||||
|
||||
local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
||||
0, -40, -- 0
|
||||
6, -38,
|
||||
15, -37,
|
||||
22, -32,
|
||||
28, -28,
|
||||
32, -22,
|
||||
37, -14,
|
||||
39, -6,
|
||||
40, 0, -- 8
|
||||
38, 7,
|
||||
37, 15,
|
||||
33, 23,
|
||||
27, 29,
|
||||
22, 33,
|
||||
14, 37,
|
||||
6, 39,
|
||||
0, 41, -- 10
|
||||
-7, 40,
|
||||
-16, 38,
|
||||
-22, 34,
|
||||
-28, 28,
|
||||
-34, 23,
|
||||
-38, 16,
|
||||
-40, 8,
|
||||
-40, -0, -- 18
|
||||
-40, -6,
|
||||
-38, -14,
|
||||
-34, -22,
|
||||
-28, -28,
|
||||
-22, -32,
|
||||
-16, -36,
|
||||
-8, -38,
|
||||
}
|
||||
|
||||
local startsave = savestate.create(1)
|
||||
|
||||
local poketime = false
|
||||
|
||||
local sprite_input = {}
|
||||
local tile_input = {}
|
||||
|
||||
local function shitsprite(x, y, t)
|
||||
if x < 0 or x >= 256 or y < 0 or y > 224 then
|
||||
sprite_input[#sprite_input+1] = 0
|
||||
sprite_input[#sprite_input+1] = 0
|
||||
sprite_input[#sprite_input+1] = 0
|
||||
else
|
||||
sprite_input[#sprite_input+1] = x
|
||||
sprite_input[#sprite_input+1] = y
|
||||
sprite_input[#sprite_input+1] = t
|
||||
end
|
||||
if t == 0 then return end
|
||||
if enable_overlay then
|
||||
gui.box(x-4, y-4, x+4, y+4)
|
||||
--gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
|
||||
gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
|
||||
--gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F')
|
||||
end
|
||||
end
|
||||
|
||||
local function shittile(x, y, t)
|
||||
tile_input[#tile_input+1] = t
|
||||
if t == 0 then return end
|
||||
if enable_overlay then
|
||||
gui.box(x-8, y-8, x+8, y+8)
|
||||
gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
|
||||
end
|
||||
end
|
||||
|
||||
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
|
||||
local spl_l = R(0x71A)
|
||||
local spl_r = R(0x71B)
|
||||
local sx_l = R(0x71C)
|
||||
local sx_r = R(0x71D)
|
||||
|
||||
local x = R(x_addr + i)
|
||||
local y = R(y_addr + i)
|
||||
local sx, sy = x, y
|
||||
if pageloc_addr ~= nil then
|
||||
local page = R(pageloc_addr + i)
|
||||
sx = sx - sx_l - (spl_l - page) * 256
|
||||
else
|
||||
sx = sx - sx_l
|
||||
end
|
||||
if hipos_addr ~= nil then
|
||||
local hipos = S(hipos_addr + i)
|
||||
sy = sy + (hipos - 1) * 256
|
||||
end
|
||||
|
||||
return sx, sy
|
||||
end
|
||||
|
||||
local reward
|
||||
local powerup_old
|
||||
local status_old
|
||||
local coins_old
|
||||
|
||||
local once = false
|
||||
local reset = true
|
||||
|
||||
emu.poweron()
|
||||
emu.unpause()
|
||||
emu.speedmode("normal")
|
||||
|
||||
local function opermode() return R(0x770) end
|
||||
local function paused() return band(R(0x776), 1) end
|
||||
local function subroutine() return R(0xE) end
|
||||
|
||||
local function getstate()
|
||||
if R(0xE) == 0xFF then return 'power' end
|
||||
if R(0x774) > 0 then return 'lagging' end
|
||||
if R(0x7A2) > 0 then return 'waiting_demo' end
|
||||
if R(0x717) > 0 then return 'playing_demo' end
|
||||
-- if R(0x770) == 0xFF then return 'power' end
|
||||
if paused() ~= 0 then return 'paused' end
|
||||
if R(0xE) == 0 then return 'world_screen' end
|
||||
-- if R(0x712) == 1 then return 'deadmusic' end
|
||||
if R(0x7CA) == 0x94 then return 'dead' end
|
||||
if R(0xE) == 4 then return 'win_flagpole' end
|
||||
if R(0xE) == 5 then return 'win_walking' end
|
||||
if R(0xE) == 6 then return 'lose' end
|
||||
-- if R(0x770) == 0 then return 'not_playing' end
|
||||
if R(0x770) == 2 then return 'win_castle' end
|
||||
if R(0x772) == 2 then return 'no_control' end
|
||||
if R(0x772) == 3 then return 'playing' end
|
||||
if R(0x770) == 1 then return 'loading' end
|
||||
if R(0x770) == 3 then return 'lose' end
|
||||
return 'unknown'
|
||||
end
|
||||
|
||||
local ok_routines = {
|
||||
[0x4] = true, -- sliding down flagpole
|
||||
[0x5] = true, -- end of level auto-walk
|
||||
[0x7] = true, -- start of level auto-walk
|
||||
[0x8] = true, -- normal (in control)
|
||||
[0x9] = true, -- acquiring mushroom
|
||||
[0xA] = true, -- losing big mario
|
||||
[0xB] = true, -- uhh
|
||||
[0xC] = true, -- acquiring fireflower
|
||||
}
|
||||
|
||||
local fuckstates = {
|
||||
power = true,
|
||||
waiting_demo = true,
|
||||
playing_demo = true,
|
||||
unknown = true,
|
||||
lose = true,
|
||||
paused = true,
|
||||
}
|
||||
|
||||
local function advance()
|
||||
emu.frameadvance()
|
||||
while emu.lagged() do emu.frameadvance() end -- skip lag frames.
|
||||
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
|
||||
end
|
||||
|
||||
local state_old = ''
|
||||
while false do
|
||||
local state = getstate()
|
||||
if state ~= state_old then
|
||||
print(emu.framecount(), state)
|
||||
state_old = state
|
||||
end
|
||||
advance()
|
||||
end
|
||||
|
||||
local function handle_enemies()
|
||||
-- enemies, flagpole
|
||||
for i = 0, 5 do
|
||||
local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
|
||||
x, y = x + 8, y + 16
|
||||
local tid = R(0x16 + i)
|
||||
local flags = R(0xF + i)
|
||||
--local offscr = R(0x3D8 + i)
|
||||
local invisible = tid < 0x10 and flags == 0
|
||||
if tid == 0x30 then y = y - 8 end -- flagpole flag
|
||||
if tid == 0x31 then y = y - 8 end -- castle flag
|
||||
if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks
|
||||
if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms
|
||||
if tid == 0x2D then x, y = x, y end -- bowser (TODO: determine head or body)
|
||||
if tid == 0x15 then x, y = x, y - 12 end -- bowser fire
|
||||
if tid == 0x32 then x, y = x, y - 8 end -- spring
|
||||
-- tid == 0x35 -- toad
|
||||
if tid == 0x1D or tid == 0x1B then -- rotating fire bars
|
||||
x, y = x - 4, y - 12
|
||||
-- this is a mess... gotta find out its rotation and then project.
|
||||
-- TODO: handle long fire bars too
|
||||
local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
|
||||
gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
|
||||
local x_off, y_off = rotation_offsets[rot*2+1], rotation_offsets[rot*2+2]
|
||||
x, y = x + x_off, y + y_off
|
||||
end
|
||||
if invisible then
|
||||
shitsprite(0, 0, 0)
|
||||
else
|
||||
shitsprite(x, y, tid + 1)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_fireballs()
|
||||
-- fireballs
|
||||
for i = 0, 1 do
|
||||
local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC)
|
||||
x, y = x + 4, y + 4
|
||||
local state = R(0x24 + i)
|
||||
local invisible = state == 0
|
||||
if invisible then
|
||||
shitsprite(0, 0, 0)
|
||||
else
|
||||
shitsprite(x, y, 257)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_blocks()
|
||||
for i = 0, 3 do
|
||||
local x, y = getxy(i, 0x8F, 0xD7, 0x76, 0xBE)
|
||||
x, y = x + 8, y + 8
|
||||
local state = R(0x26 + i)
|
||||
local invisible = state == 0
|
||||
if invisible then
|
||||
shitsprite(0, 0, 0)
|
||||
else
|
||||
shitsprite(x, y, 258)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_hammers()
|
||||
-- hammers, coins, score bonus text...
|
||||
for i = 0, 8 do
|
||||
local x, y = getxy(i, 0x93, 0xDB, 0x7A, 0xC2)
|
||||
x, y = x + 8, y + 8
|
||||
local state = R(0x2A + i)
|
||||
-- skip coin effect states. not interactable; we don't care!
|
||||
if state ~= 0
|
||||
and state >= 0x30
|
||||
then
|
||||
shitsprite(x, y, state + 1)
|
||||
else
|
||||
shitsprite(0, 0, 0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_misc()
|
||||
for i = 0, 0 do
|
||||
local x, y = getxy(i, 0x9C, 0xE4, 0x83, 0xCB)
|
||||
x, y = x + 8, y + 8
|
||||
local state = R(0x33 + i)
|
||||
if state ~= 0 then
|
||||
shitsprite(x, y, state + 1)
|
||||
else
|
||||
shitsprite(0, 0, 0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_tiles()
|
||||
--local tile_col = R(0x6A0)
|
||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||
local tile_scroll_remainder = R(0x73F) % 16
|
||||
tile_input[#tile_input+1] = tile_scroll_remainder
|
||||
for y = 0, 12 do
|
||||
for x = 0, 16 do
|
||||
local col = (x + tile_scroll) % 32
|
||||
local t
|
||||
if col < 16 then
|
||||
t = R(0x500 + y * 16 + (col % 16))
|
||||
else
|
||||
t = R(0x5D0 + y * 16 + (col % 16))
|
||||
end
|
||||
local sx = x * 16 + 8 - tile_scroll_remainder
|
||||
local sy = y * 16 + 40
|
||||
shittile(sx, sy, t)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
while true do
|
||||
while fuckstates[getstate()] do
|
||||
--gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F')
|
||||
-- mash the start button until we have control.
|
||||
-- TODO: learn this too.
|
||||
--local jp = joypad.read(1)
|
||||
local jp = {
|
||||
up = false,
|
||||
down = false,
|
||||
left = false,
|
||||
right = false,
|
||||
A = false,
|
||||
B = false,
|
||||
select = false,
|
||||
start = emu.framecount() % 2 == 1,
|
||||
}
|
||||
joypad.write(1, jp)
|
||||
|
||||
reset = true
|
||||
|
||||
gui.text(4, 12, getstate(), '#FFFFFF', '#0000003F')
|
||||
advance()
|
||||
|
||||
-- bit of a hack:
|
||||
while getstate() == "loading" do advance() end
|
||||
state_old = getstate()
|
||||
end
|
||||
|
||||
if reset then
|
||||
print("resetting in state:", getstate())
|
||||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
print("end of trial reward:", reward)
|
||||
print()
|
||||
else
|
||||
savestate.save(startsave)
|
||||
end
|
||||
once = true
|
||||
|
||||
-- bit of a hack:
|
||||
if getstate() == 'loading' then advance() end
|
||||
reward = 0
|
||||
powerup_old = R(0x754)
|
||||
status_old = R(0x756)
|
||||
coins_old = R(0x7ED) * 10 + R(0x7EE)
|
||||
|
||||
-- set lives to 0. you only got one shot!
|
||||
-- unless you get a 1-up, in which case, please continue!
|
||||
W(0x75A, 0)
|
||||
|
||||
emu.frameadvance() -- prevents emulator from quirking up.
|
||||
|
||||
reset = false
|
||||
if network ~= nil then network:reset() end -- FIXME: hack
|
||||
end
|
||||
|
||||
if not enable_network then
|
||||
-- infinite time cheat. super handy for testing.
|
||||
if R(0xE) == 8 then
|
||||
W(0x7F8, 9)
|
||||
W(0x7F9, 9)
|
||||
W(0x7FA, 10)
|
||||
poketime = true
|
||||
elseif poketime then
|
||||
poketime = false
|
||||
W(0x7F8, 0)
|
||||
W(0x7F9, 0)
|
||||
W(0x7FA, 1)
|
||||
end
|
||||
|
||||
-- infinite lives.
|
||||
W(0x75A, 1)
|
||||
end
|
||||
|
||||
-- empty input lists without creating a new table.
|
||||
for k, v in pairs(sprite_input) do sprite_input[k] = nil end
|
||||
for k, v in pairs(tile_input) do tile_input[k] = nil end
|
||||
|
||||
-- player
|
||||
-- TODO: add check if mario is playable.
|
||||
local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||
local powerup = R(0x754)
|
||||
local status = R(0x756)
|
||||
shitsprite(x + 8, y + 24, -powerup - 1)
|
||||
|
||||
handle_enemies()
|
||||
handle_fireballs()
|
||||
-- blocks being hit. not interactable; we don't care!
|
||||
--handle_blocks()
|
||||
handle_hammers()
|
||||
handle_misc()
|
||||
handle_tiles()
|
||||
|
||||
local coins = R(0x7ED) * 10 + R(0x7EE)
|
||||
local coins_delta = coins - coins_old
|
||||
-- handle wrap-around.
|
||||
if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
|
||||
-- remember that 0 is big mario and 1 is small mario.
|
||||
local powerup_delta = powerup_old - powerup
|
||||
-- 2 is fire mario.
|
||||
local status_delta = clamp(status - status_old, -1, 1)
|
||||
local screen_scroll_delta = R(0x775)
|
||||
local flagpole_bonus = subroutine() == 4 and 1 or 0
|
||||
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||
|
||||
reward = reward + reward_delta
|
||||
|
||||
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
|
||||
--gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
|
||||
gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
||||
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
||||
|
||||
if getstate() == 'dead' and state_old ~= 'dead' then
|
||||
print("dead. lives remaining:", R(0x75A, 0))
|
||||
if R(0x75A, 0) == 0 then reset = true end
|
||||
end
|
||||
if getstate() == 'lose' then
|
||||
print("ran out of lives.")
|
||||
reset = true
|
||||
end
|
||||
|
||||
-- every few frames mario stands still, forcibly decrease the timer.
|
||||
local timer_loser = 1/5
|
||||
if math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||
local timer = R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
||||
timer = clamp(timer - 1, 0, 400)
|
||||
W(0x7F8, floor(timer / 100))
|
||||
W(0x7F9, floor((timer / 10) % 10))
|
||||
W(0x7FA, floor(timer % 10))
|
||||
end
|
||||
|
||||
local X = {} -- TODO: cache.
|
||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
||||
|
||||
if network == nil then
|
||||
network = make_network(#X, 8); network:reset()
|
||||
print("parameters:", network.n_param)
|
||||
end
|
||||
|
||||
if enable_network and getstate() == 'playing' then
|
||||
local outputs = network:forward(X)
|
||||
local softmaxed = {
|
||||
outputs[nn_z[1]],
|
||||
outputs[nn_z[2]],
|
||||
outputs[nn_z[3]],
|
||||
outputs[nn_z[4]],
|
||||
outputs[nn_z[5]],
|
||||
outputs[nn_z[6]],
|
||||
outputs[nn_z[7]],
|
||||
outputs[nn_z[8]],
|
||||
}
|
||||
local jp = {
|
||||
up = argmax2(softmaxed[1]),
|
||||
down = argmax2(softmaxed[2]),
|
||||
left = argmax2(softmaxed[3]),
|
||||
right = argmax2(softmaxed[4]),
|
||||
A = argmax2(softmaxed[5]),
|
||||
B = argmax2(softmaxed[6]),
|
||||
start = argmax2(softmaxed[7]),
|
||||
select = argmax2(softmaxed[8]),
|
||||
}
|
||||
joypad.write(1, jp)
|
||||
end
|
||||
|
||||
coins_old = coins
|
||||
powerup_old = powerup
|
||||
status_old = status
|
||||
state_old = getstate()
|
||||
advance()
|
||||
end
|
377
nn.lua
Normal file
377
nn.lua
Normal file
|
@ -0,0 +1,377 @@
|
|||
local tostring = tostring
|
||||
local ipairs = ipairs
|
||||
local pairs = pairs
|
||||
local uniform = math.random
|
||||
local sqrt = math.sqrt
|
||||
local log = math.log
|
||||
local pi = math.pi
|
||||
local exp = math.exp
|
||||
local min = math.min
|
||||
local max = math.max
|
||||
local cos = math.cos
|
||||
local sin = math.sin
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
|
||||
local bor = bit.bor
|
||||
|
||||
local Base = require("Base")
|
||||
|
||||
local function contains(t, a)
|
||||
assert(type(t) == "table")
|
||||
for k, v in pairs(t) do if v == a then return true end end
|
||||
return false
|
||||
end
|
||||
|
||||
local function normal()
|
||||
return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform()) / 2
|
||||
end
|
||||
|
||||
local function zeros(n, out)
|
||||
local out = out or {}
|
||||
for i = 1, n do out[i] = 0 end
|
||||
return out
|
||||
end
|
||||
|
||||
local function init_zeros(t, fan_in, fan_out)
|
||||
for i = 1, #t do t[i] = 0 end
|
||||
return t
|
||||
end
|
||||
|
||||
local function init_he_uniform(t, fan_in, fan_out)
|
||||
local s = sqrt(6 / fan_in)
|
||||
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
|
||||
return t
|
||||
end
|
||||
|
||||
local function init_he_normal(t, fan_in, fan_out)
|
||||
local s = sqrt(2 / fan_in)
|
||||
for i = 1, #t do t[i] = normal() * s end
|
||||
return t
|
||||
end
|
||||
|
||||
local function copy(t) -- shallow copy
|
||||
local new_t = {}
|
||||
for k, v in pairs(t) do new_t[k] = v end
|
||||
return new_t
|
||||
end
|
||||
|
||||
local function allocate(t, out, init)
|
||||
-- FIXME: this code is fucking disgusting.
|
||||
|
||||
out = out or {}
|
||||
assert(type(out) == "table", type(out))
|
||||
if type(t) == "number" then
|
||||
local size = t
|
||||
if init ~= nil then
|
||||
return init(zeros(size, out))
|
||||
else
|
||||
return zeros(size, out)
|
||||
end
|
||||
end
|
||||
local topsize = t[1]
|
||||
t = copy(t)
|
||||
remove(t, 1)
|
||||
if #t == 1 then t = t[1] end
|
||||
for i = 1, topsize do
|
||||
local res = allocate(t, nil, init)
|
||||
assert(res ~= nil)
|
||||
insert(out, res)
|
||||
end
|
||||
return out
|
||||
end
|
||||
|
||||
local Weights = Base:extend()
|
||||
local Layer = Base:extend()
|
||||
local Model = Base:extend()
|
||||
local Input = Layer:extend()
|
||||
local Relu = Layer:extend()
|
||||
local Dense = Layer:extend()
|
||||
local Softmax = Layer:extend()
|
||||
|
||||
function Weights:init(weight_init)
|
||||
self.weight_init = weight_init
|
||||
end
|
||||
|
||||
function Weights:allocate(fan_in, fan_out)
|
||||
return allocate(self.size, self, function(t)
|
||||
--print('initializing weights of size', self.size, 'with fans', fan_in, fan_out)
|
||||
return self.weight_init(t, fan_in, fan_out)
|
||||
end)
|
||||
end
|
||||
|
||||
--[[
|
||||
local w = Weights(init_he_uniform)
|
||||
w.size = {16, 16}
|
||||
w:allocate(16, 16)
|
||||
print(w)
|
||||
do return end
|
||||
|
||||
local w = zeros(16)
|
||||
for i = 1, #w do w[i] = normal() * 1920 / 2560 end
|
||||
print(w)
|
||||
--]]
|
||||
|
||||
local counter = {}
|
||||
function Layer:init(name)
|
||||
assert(type(name) == "string")
|
||||
counter[name] = (counter[name] or 0) + 1
|
||||
self.name = name.."["..tostring(counter[name]).."]"
|
||||
self.parents = {}
|
||||
self.children = {}
|
||||
self.weights = {}
|
||||
--self.size_in = nil
|
||||
--self.size_out = nil
|
||||
end
|
||||
|
||||
function Layer:make_shape(parent)
|
||||
if self.size_in == nil then self.size_in = parent.size_out end
|
||||
if self.size_out == nil then self.size_out = self.size_in end
|
||||
end
|
||||
|
||||
function Layer:feed(child)
|
||||
assert(self.size_out ~= nil)
|
||||
child:make_shape(self)
|
||||
insert(self.children, child)
|
||||
insert(child.parents, self)
|
||||
return child
|
||||
end
|
||||
|
||||
function Layer:forward()
|
||||
error("Unimplemented.")
|
||||
end
|
||||
|
||||
function Layer:forward_deterministic(...)
|
||||
return self:forward(...)
|
||||
end
|
||||
|
||||
function Layer:_new_weights(init)
|
||||
local w = Weights(init)
|
||||
insert(self.weights, w)
|
||||
return w
|
||||
end
|
||||
|
||||
local function prod(x, ...)
|
||||
if type(x) == "table" then
|
||||
return prod(unpack(x))
|
||||
end
|
||||
local ret = x
|
||||
for i = 1, select("#", ...) do
|
||||
ret = ret * select(i, ...)
|
||||
end
|
||||
return ret
|
||||
end
|
||||
|
||||
function Layer:getsize()
|
||||
local size = 0
|
||||
for i, w in ipairs(self.weights) do size = size + prod(w.size) end
|
||||
return size
|
||||
end
|
||||
|
||||
function Layer:init_weights()
|
||||
for i, w in ipairs(self.weights) do
|
||||
--print("allocating weights", i, "of", self.name)
|
||||
for j, v in ipairs(w) do w[j] = nil end -- FIXME: HACK
|
||||
w:allocate(self.size_in, self.size_out)
|
||||
end
|
||||
end
|
||||
|
||||
function Layer:_propagate(edges, deterministic)
|
||||
assert(#edges == 1, #edges) -- override this if you need multiple parents.
|
||||
if deterministic then
|
||||
return self:forward_deterministic(edges[1])
|
||||
else
|
||||
return self:forward(edges[1])
|
||||
end
|
||||
end
|
||||
|
||||
function Layer:propagate(values, deterministic)
|
||||
local edges = {}
|
||||
for i, parent in ipairs(self.parents) do
|
||||
if values[parent] ~= nil then
|
||||
local X = values[parent]
|
||||
insert(edges, X)
|
||||
end
|
||||
end
|
||||
assert(#edges > 0, #edges)
|
||||
local Y = self:_propagate(edges, deterministic)
|
||||
return Y
|
||||
end
|
||||
|
||||
function Input:init(size)
|
||||
Layer.init(self, "Input")
|
||||
assert(type(size) == 'number')
|
||||
self.size_in = size
|
||||
self.size_out = size
|
||||
end
|
||||
|
||||
function Input:forward(X)
|
||||
assert(#X == self.size_in)
|
||||
return X
|
||||
end
|
||||
|
||||
function Relu:init()
|
||||
Layer.init(self, "Relu")
|
||||
end
|
||||
|
||||
function Relu:forward(X)
|
||||
assert(#X == self.size_in)
|
||||
self.cache = self.cache or zeros(self.size_out)
|
||||
local Y = self.cache
|
||||
|
||||
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
|
||||
|
||||
assert(#Y == self.size_out)
|
||||
return Y
|
||||
end
|
||||
|
||||
function Dense:init(dim)
|
||||
Layer.init(self, "Dense")
|
||||
assert(type(dim) == "number")
|
||||
self.dim = dim
|
||||
self.size_out = dim
|
||||
self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
|
||||
self.biases = self:_new_weights(init_zeros)
|
||||
end
|
||||
|
||||
function Dense:make_shape(parent)
|
||||
self.size_in = parent.size_out
|
||||
self.coeffs.size = {self.dim, self.size_in}
|
||||
self.biases.size = self.dim
|
||||
end
|
||||
|
||||
function Dense:forward(X)
|
||||
assert(#X == self.size_in)
|
||||
self.cache = self.cache or zeros(self.size_out)
|
||||
local Y = self.cache
|
||||
|
||||
for i = 1, #self.coeffs do
|
||||
local res = 0
|
||||
local c = self.coeffs[i]
|
||||
for j = 1, #X do
|
||||
res = res + X[j] * c[j]
|
||||
end
|
||||
Y[i] = res + self.biases[i]
|
||||
end
|
||||
|
||||
assert(#Y == self.size_out)
|
||||
return Y
|
||||
end
|
||||
|
||||
function Softmax:init()
|
||||
Layer.init(self, "Softmax")
|
||||
end
|
||||
|
||||
function Softmax:forward(X)
|
||||
assert(#X == self.size_in)
|
||||
self.cache = self.cache or zeros(self.size_out)
|
||||
local Y = self.cache
|
||||
|
||||
local alpha = 0
|
||||
local num = {} -- TODO: cache
|
||||
local den = 0
|
||||
|
||||
for i = 1, #X do alpha = max(alpha, X[i]) end
|
||||
for i = 1, #X do num[i] = exp(X[i] - alpha) end
|
||||
for i = 1, #X do den = den + num[i] end
|
||||
for i = 1, #X do Y[i] = num[i] / den end
|
||||
|
||||
assert(#Y == self.size_out)
|
||||
return Y
|
||||
end
|
||||
|
||||
local function levelorder(field, node_in, nodes)
|
||||
-- horribly inefficient.
|
||||
nodes = nodes or {}
|
||||
local q = {node_in}
|
||||
while #q > 0 do
|
||||
local node = q[1]
|
||||
remove(q, 1)
|
||||
insert(nodes, node)
|
||||
for _, child in ipairs(node[field]) do
|
||||
q[#q+1] = child
|
||||
end
|
||||
end
|
||||
return nodes
|
||||
end
|
||||
|
||||
local function traverse(node_in, node_out, nodes)
|
||||
nodes = nodes or {}
|
||||
local down = levelorder('children', node_in, {})
|
||||
local up = levelorder('parents', node_out, {})
|
||||
local seen = {}
|
||||
for _, node in ipairs(up) do
|
||||
seen[node] = bor(seen[node] or 0, 1)
|
||||
end
|
||||
for _, node in ipairs(down) do
|
||||
seen[node] = bor(seen[node] or 0, 2)
|
||||
if seen[node] == 3 then
|
||||
insert(nodes, node)
|
||||
end
|
||||
end
|
||||
return nodes
|
||||
end
|
||||
|
||||
local function traverse_all(nodes_in, nodes_out, nodes)
|
||||
local all_in = {children={}}
|
||||
local all_out = {parents={}}
|
||||
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
|
||||
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
|
||||
return traverse(all_in, all_out, nodes or {})
|
||||
end
|
||||
|
||||
function Model:init(nodes_in, nodes_out)
|
||||
assert(#nodes_in > 0, #nodes_in)
|
||||
assert(#nodes_out > 0, #nodes_out)
|
||||
--if #nodes_in == 0 and type(nodes_in) == "table" then nodes_in = {nodes_in} end
|
||||
--if #nodes_out == 0 and type(nodes_out) == "table" then nodes_out = {nodes_out} end
|
||||
|
||||
self.nodes_in = nodes_in
|
||||
self.nodes_out = nodes_out
|
||||
|
||||
-- find all the used (inbetween) nodes in the graph.
|
||||
self.nodes = traverse_all(self.nodes_in, self.nodes_out)
|
||||
end
|
||||
|
||||
function Model:reset()
|
||||
self.n_param = 0
|
||||
for _, node in ipairs(self.nodes) do
|
||||
node:init_weights()
|
||||
self.n_param = self.n_param + node:getsize()
|
||||
end
|
||||
end
|
||||
|
||||
function Model:forward(X)
|
||||
local values = {}
|
||||
local outputs = {}
|
||||
for i, node in ipairs(self.nodes) do
|
||||
--print(i, node.name)
|
||||
if contains(self.nodes_in, node) then
|
||||
values[node] = node:_propagate({X})
|
||||
else
|
||||
values[node] = node:propagate(values)
|
||||
end
|
||||
if contains(self.nodes_out, node) then
|
||||
outputs[node] = values[node]
|
||||
end
|
||||
end
|
||||
return outputs
|
||||
end
|
||||
|
||||
return {
|
||||
uniform = uniform,
|
||||
normal = normal,
|
||||
|
||||
zeros = zeros,
|
||||
init_zeros = init_zeros,
|
||||
init_he_uniform = init_he_uniform,
|
||||
init_he_normal = init_he_normal,
|
||||
|
||||
Weights = Weights,
|
||||
Layer = Layer,
|
||||
Model = Model,
|
||||
Input = Input,
|
||||
Relu = Relu,
|
||||
Dense = Dense,
|
||||
Softmax = Softmax,
|
||||
}
|
Loading…
Reference in a new issue