basic learning

This commit is contained in:
Connor Olding 2017-06-29 04:51:02 +00:00
parent 75ad46bfe9
commit e2d29352c4
2 changed files with 391 additions and 184 deletions

359
main.lua
View file

@ -20,6 +20,63 @@ local function globalize(t)
end end
end end
-- configuration and globals.
-- seed math.random with a fixed value so its sequence repeats from run to run.
math.randomseed(10)
-- learn_from_epoch() steps the parameters with magnitude learning_rate / deviation.
local learning_rate = 1e-2
-- load_next_trial() multiplies a trial's noise by this before adding it to base_params.
local deviation = 2e-2
-- number of the current epoch; stays 0 until do_reset() prepares the first one.
local epoch_i = 0
-- how many trials (noise vectors / rewards) make up one epoch.
local epoch_trials = 12
-- flat parameter array returned by network:collect(); assigned in prepare_epoch().
local base_params
-- index of the trial currently running; prepare_epoch() sets it back to 0.
local trial_i = 0
-- trial_noise[i] is the noise vector generated for trial i of the current epoch.
local trial_noise = {}
-- trial_rewards[i] is the reward do_reset() recorded when trial i ended.
local trial_rewards = {}
-- NOTE(review): not referenced anywhere else in the visible code.
local trials_remaining = 0
-- set every frame by the main loop: true while the game is paused and the timer reads 0.
local force_start = false
-- value of force_start from the previous frame.
local force_start_old = false
-- checked by mark_tile() before it runs its overlay branch.
local enable_overlay = false
-- when false the main loop applies the infinite time/lives cheats and skips the network.
local enable_network = true
-- savestate that do_reset() saves on its first call and loads on every later call.
local startsave = savestate.create(1)
-- bookkeeping flag for the infinite-time cheat in the main loop.
local poketime = false
-- per-frame network inputs: appended to by mark_sprite() and mark_tile(), emptied each frame.
local sprite_input = {}
local tile_input = {}
-- reward accumulated so far in the current trial.
local reward
-- values observed on the previous frame, used by the main loop to compute deltas.
local powerup_old
local status_old
local coins_old
-- becomes true after do_reset() has run once (i.e. startsave has been written).
local once = false
-- when true, the main loop calls do_reset() before processing the frame.
local reset = true
-- NOTE(review): presumably keyed by the routine number read from RAM 0xE;
-- the code that consults this table is not visible here — confirm.
local ok_routines = {
[0x4] = true, -- sliding down flagpole
[0x5] = true, -- end of level auto-walk
[0x7] = true, -- start of level auto-walk
[0x8] = true, -- normal (in control)
[0x9] = true, -- acquiring mushroom
[0xA] = true, -- losing big mario
[0xB] = true, -- uhh
[0xC] = true, -- acquiring fireflower
}
-- get_state() results during which the main loop stays in its
-- "mash the start button until we have control" loop.
local bad_states = {
power = true,
waiting_demo = true,
playing_demo = true,
unknown = true,
lose = true,
}
-- get_state() result recorded at the end of the previous frame.
local state_old = ''
-- localize some stuff. -- localize some stuff.
local print = print local print = print
@ -51,6 +108,14 @@ local arshift = bit.arshift
local rol = bit.rol local rol = bit.rol
local ror = bit.ror local ror = bit.ror
-- utilities.
-- exclusive-or on truthiness: true when exactly one of a, b is truthy
-- (anything other than nil/false counts as truthy).
local function boolean_xor(a, b)
    -- `not` collapses each operand to a real boolean, so the two sides
    -- differ precisely when one operand is truthy and the other is not.
    return (not a) ~= (not b)
end
local function clamp(x, l, u) return min(max(x, l), u) end local function clamp(x, l, u) return min(max(x, l), u) end
local function argmax(...) local function argmax(...)
@ -70,6 +135,35 @@ local function argmax2(t)
return t[1] > t[2] return t[1] > t[2]
end end
-- clear every entry of t in place (no new table is created) and
-- return the same table so the call can be chained.
local function empty(t)
    -- assigning nil to existing keys is permitted while traversing with pairs.
    for key in pairs(t) do
        t[key] = nil
    end
    return t
end
-- compute the mean of the sequence x and its mean squared deviation
-- from that mean (i.e. the population variance, not its square root).
-- returns: mean, variance. an empty sequence yields 0, 0.
local function calc_mean_dev(x)
    local count = #x

    -- each term is divided before it is added, rather than dividing the sum once.
    local average = 0
    for _, value in ipairs(x) do
        average = average + value / count
    end

    local variance = 0
    for _, value in ipairs(x) do
        local offset = value - average
        variance = variance + offset * offset / count
    end

    return average, variance
end
-- shift and scale the sequence x so it has zero mean and unit deviation.
-- x:   sequence of numbers.
-- out: optional destination table; when omitted, x is overwritten in place.
-- returns the table that received the results.
local function normalize(x, out)
    local target = out or x
    local mean, variance = calc_mean_dev(x)
    -- no spread (constant or empty input): divide by 1 instead of 0.
    if variance <= 0 then
        variance = 1
    end
    local scale = sqrt(variance)
    for i, value in ipairs(x) do
        target[i] = (value - mean) / scale
    end
    return target
end
-- game-agnostic stuff (i.e. the network itself) -- game-agnostic stuff (i.e. the network itself)
package.loaded['nn'] = nil -- DEBUG package.loaded['nn'] = nil -- DEBUG
@ -100,9 +194,6 @@ end
-- and here we go with the game stuff. -- and here we go with the game stuff.
local enable_overlay = false
local enable_network = true
--[[ --[[
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
--]] --]]
@ -142,14 +233,17 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
-8, -38, -8, -38,
} }
local startsave = savestate.create(1) local function get_timer()
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
end
local poketime = false local function set_timer(time)
W(0x7F8, floor(time / 100))
W(0x7F9, floor((time / 10) % 10))
W(0x7FA, floor(time % 10))
end
local sprite_input = {} local function mark_sprite(x, y, t)
local tile_input = {}
local function shitsprite(x, y, t)
if x < 0 or x >= 256 or y < 0 or y > 224 then if x < 0 or x >= 256 or y < 0 or y > 224 then
sprite_input[#sprite_input+1] = 0 sprite_input[#sprite_input+1] = 0
sprite_input[#sprite_input+1] = 0 sprite_input[#sprite_input+1] = 0
@ -168,7 +262,7 @@ local function shitsprite(x, y, t)
end end
end end
local function shittile(x, y, t) local function mark_tile(x, y, t)
tile_input[#tile_input+1] = t tile_input[#tile_input+1] = t
if t == 0 then return end if t == 0 then return end
if enable_overlay then if enable_overlay then
@ -200,23 +294,9 @@ local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
return sx, sy return sx, sy
end end
local reward
local powerup_old
local status_old
local coins_old
local once = false
local reset = true
emu.poweron()
emu.unpause()
emu.speedmode("normal")
local function opermode() return R(0x770) end
local function paused() return band(R(0x776), 1) end local function paused() return band(R(0x776), 1) end
local function subroutine() return R(0xE) end
local function getstate() local function get_state()
if R(0xE) == 0xFF then return 'power' end if R(0xE) == 0xFF then return 'power' end
if R(0x774) > 0 then return 'lagging' end if R(0x774) > 0 then return 'lagging' end
if R(0x7A2) > 0 then return 'waiting_demo' end if R(0x7A2) > 0 then return 'waiting_demo' end
@ -238,35 +318,14 @@ local function getstate()
return 'unknown' return 'unknown'
end end
local ok_routines = {
[0x4] = true, -- sliding down flagpole
[0x5] = true, -- end of level auto-walk
[0x7] = true, -- start of level auto-walk
[0x8] = true, -- normal (in control)
[0x9] = true, -- acquiring mushroom
[0xA] = true, -- losing big mario
[0xB] = true, -- uhh
[0xC] = true, -- acquiring fireflower
}
local fuckstates = {
power = true,
waiting_demo = true,
playing_demo = true,
unknown = true,
lose = true,
paused = true,
}
local function advance() local function advance()
emu.frameadvance() emu.frameadvance()
while emu.lagged() do emu.frameadvance() end -- skip lag frames. while emu.lagged() do emu.frameadvance() end -- skip lag frames.
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames. while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
end end
local state_old = ''
while false do while false do
local state = getstate() local state = get_state()
if state ~= state_old then if state ~= state_old then
print(emu.framecount(), state) print(emu.framecount(), state)
state_old = state state_old = state
@ -301,9 +360,9 @@ local function handle_enemies()
x, y = x + x_off, y + y_off x, y = x + x_off, y + y_off
end end
if invisible then if invisible then
shitsprite(0, 0, 0) mark_sprite(0, 0, 0)
else else
shitsprite(x, y, tid + 1) mark_sprite(x, y, tid + 1)
end end
end end
end end
@ -316,9 +375,9 @@ local function handle_fireballs()
local state = R(0x24 + i) local state = R(0x24 + i)
local invisible = state == 0 local invisible = state == 0
if invisible then if invisible then
shitsprite(0, 0, 0) mark_sprite(0, 0, 0)
else else
shitsprite(x, y, 257) mark_sprite(x, y, 257)
end end
end end
end end
@ -330,9 +389,9 @@ local function handle_blocks()
local state = R(0x26 + i) local state = R(0x26 + i)
local invisible = state == 0 local invisible = state == 0
if invisible then if invisible then
shitsprite(0, 0, 0) mark_sprite(0, 0, 0)
else else
shitsprite(x, y, 258) mark_sprite(x, y, 258)
end end
end end
end end
@ -347,9 +406,9 @@ local function handle_hammers()
if state ~= 0 if state ~= 0
and state >= 0x30 and state >= 0x30
then then
shitsprite(x, y, state + 1) mark_sprite(x, y, state + 1)
else else
shitsprite(0, 0, 0) mark_sprite(0, 0, 0)
end end
end end
end end
@ -360,9 +419,9 @@ local function handle_misc()
x, y = x + 8, y + 8 x, y = x + 8, y + 8
local state = R(0x33 + i) local state = R(0x33 + i)
if state ~= 0 then if state ~= 0 then
shitsprite(x, y, state + 1) mark_sprite(x, y, state + 1)
else else
shitsprite(0, 0, 0) mark_sprite(0, 0, 0)
end end
end end
end end
@ -383,25 +442,104 @@ local function handle_tiles()
end end
local sx = x * 16 + 8 - tile_scroll_remainder local sx = x * 16 + 8 - tile_scroll_remainder
local sy = y * 16 + 40 local sy = y * 16 + 40
shittile(sx, sy, t) mark_tile(sx, sy, t)
end end
end end
end end
local function doreset() local function learn_from_epoch()
print("resetting in state:", getstate()) print()
print('rewards:', trial_rewards)
normalize(trial_rewards)
print('normalized:', trial_rewards)
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
local step = nn.zeros(#base_params)
for i = 1, epoch_trials do
local reward = trial_rewards[i]
local noise = trial_noise[i]
for j, v in ipairs(noise) do
step[j] = step[j] + reward * v
end
end
local magnitude = learning_rate / deviation
print('stepping with magnitude', magnitude)
-- throw the division from the averaging in there too.
local altogether = magnitude / epoch_trials
for i, v in ipairs(step) do
step[i] = altogether * v
end
local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean)
print("step stddev:", step_dev)
if step_dev > 1e-8 then
for i, v in ipairs(base_params) do
base_params[i] = v + step[i]
end
else
-- we didn't get anywhere. step in a random direction.
local noise = trial_noise[1]
for i, v in ipairs(base_params) do
base_params[i] = v + magnitude * noise[i]
end
end
network:distribute(base_params)
network:save()
print()
end
-- start a new epoch: snapshot the network's current parameters into
-- base_params, throw away the previous epoch's noise and rewards, and
-- draw one fresh noise vector (one nn.normal() sample per parameter)
-- for each of the epoch_trials trials. resets trial_i to 0.
local function prepare_epoch()
    print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')

    base_params = network:collect()
    local n_params = #base_params

    empty(trial_noise)
    empty(trial_rewards)

    for trial = 1, epoch_trials do
        local samples = nn.zeros(n_params)
        for p = 1, n_params do
            samples[p] = nn.normal()
        end
        trial_noise[trial] = samples
    end

    trial_i = 0
end
-- advance to the next trial of the epoch and load its perturbed
-- parameters (base_params + deviation * noise) into the network.
local function load_next_trial()
    trial_i = trial_i + 1
    print('loading trial', trial_i)

    local perturbation = trial_noise[trial_i]
    -- named `candidate` rather than `W` so it doesn't shadow the RAM-write helper.
    local candidate = nn.copy(base_params)
    for index, base in ipairs(base_params) do
        candidate[index] = base + deviation * perturbation[index]
    end

    network:distribute(candidate)
end
local function do_reset()
print("resetting in state:", get_state())
if trial_i > 0 then trial_rewards[trial_i] = reward end
if epoch_i == 0 or trial_i == epoch_trials then
if epoch_i > 0 then learn_from_epoch() else network:reset() end
epoch_i = epoch_i + 1
prepare_epoch()
end
if once then if once then
savestate.load(startsave) savestate.load(startsave)
print("end of trial reward:", reward) print("end of trial reward:", reward)
print()
else else
savestate.save(startsave) savestate.save(startsave)
end end
once = true once = true
-- bit of a hack: -- bit of a hack:
if getstate() == 'loading' then advance() end if get_state() == 'loading' then advance() end
reward = 0 reward = 0
powerup_old = R(0x754) powerup_old = R(0x754)
status_old = R(0x756) status_old = R(0x756)
@ -413,13 +551,29 @@ local function doreset()
emu.frameadvance() -- prevents emulator from quirking up. emu.frameadvance() -- prevents emulator from quirking up.
print()
load_next_trial()
reset = false reset = false
if network ~= nil then network:reset() end -- FIXME: hack
end end
-- one-time setup, called once right before the main loop:
-- builds the network, then powers on and unpauses the emulator.
local function init()
-- NOTE(review): 279 presumably equals the length of the input vector X the
-- main loop builds from sprite_input and tile_input, and 8 the number of
-- joypad buttons — confirm against make_network.
network = make_network(279, 8)
-- Model:reset() initializes every node's weights and computes network.n_param.
network:reset()
print("parameters:", network.n_param)
emu.poweron()
emu.unpause()
emu.speedmode("normal")
end
init()
while true do while true do
while fuckstates[getstate()] do gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
--gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F')
while bad_states[get_state()] do
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
-- mash the start button until we have control. -- mash the start button until we have control.
-- TODO: learn this too. -- TODO: learn this too.
--local jp = joypad.read(1) --local jp = joypad.read(1)
@ -437,44 +591,39 @@ while true do
reset = true reset = true
gui.text(4, 12, getstate(), '#FFFFFF', '#0000003F')
advance() advance()
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
-- bit of a hack: -- bit of a hack:
while getstate() == "loading" do advance() end while get_state() == "loading" do advance() end
state_old = getstate() state_old = get_state()
end end
if reset then doreset() end if reset then do_reset() end
if not enable_network then if not enable_network then
-- infinite time cheat. super handy for testing. -- infinite time cheat. super handy for testing.
if R(0xE) == 8 then if R(0xE) == 8 then
W(0x7F8, 9) set_timer(667)
W(0x7F9, 9)
W(0x7FA, 10)
poketime = true poketime = true
elseif poketime then elseif poketime then
poketime = false poketime = false
W(0x7F8, 0) set_timer(1)
W(0x7F9, 0)
W(0x7FA, 1)
end end
-- infinite lives. -- infinite lives.
W(0x75A, 1) W(0x75A, 1)
end end
-- empty input lists without creating a new table. empty(sprite_input)
for k, v in pairs(sprite_input) do sprite_input[k] = nil end empty(tile_input)
for k, v in pairs(tile_input) do tile_input[k] = nil end
-- player -- player
-- TODO: add check if mario is playable. -- TODO: check if mario is playable.
local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5) local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
local powerup = R(0x754) local powerup = R(0x754)
local status = R(0x756) local status = R(0x756)
shitsprite(x + 8, y + 24, -powerup - 1) mark_sprite(x + 8, y + 24, -powerup - 1)
handle_enemies() handle_enemies()
handle_fireballs() handle_fireballs()
@ -484,6 +633,8 @@ while true do
handle_misc() handle_misc()
handle_tiles() handle_tiles()
local ingame_paused = get_state() == "paused"
local coins = R(0x7ED) * 10 + R(0x7EE) local coins = R(0x7ED) * 10 + R(0x7EE)
local coins_delta = coins - coins_old local coins_delta = coins - coins_old
-- handle wrap-around. -- handle wrap-around.
@ -493,45 +644,48 @@ while true do
-- 2 is fire mario. -- 2 is fire mario.
local status_delta = clamp(status - status_old, -1, 1) local status_delta = clamp(status - status_old, -1, 1)
local screen_scroll_delta = R(0x775) local screen_scroll_delta = R(0x775)
local flagpole_bonus = subroutine() == 4 and 1 or 0 local flagpole_bonus = R(0xE) == 4 and 1 or 0
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
reward = reward + reward_delta if not ingame_paused then reward = reward + reward_delta end
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F') --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
--gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F') --gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F') gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
if getstate() == 'dead' and state_old ~= 'dead' then if get_state() == 'dead' and state_old ~= 'dead' then
print("dead. lives remaining:", R(0x75A, 0)) print("dead. lives remaining:", R(0x75A, 0))
if R(0x75A, 0) == 0 then reset = true end if R(0x75A, 0) == 0 then reset = true end
end end
if getstate() == 'lose' then if get_state() == 'lose' then
print("ran out of lives.") print("ran out of lives.")
reset = true reset = true
end end
-- lose a point for every frame paused.
if ingame_paused then reward = reward - 1 end
-- every few frames mario stands still, forcibly decrease the timer. -- every few frames mario stands still, forcibly decrease the timer.
-- this includes having the game paused.
-- TODO: more robust. doesn't detect moonwalking against a wall.
local timer = get_timer()
local timer_loser = 1/5 local timer_loser = 1/5
if math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
local timer = R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
timer = clamp(timer - 1, 0, 400) timer = clamp(timer - 1, 0, 400)
W(0x7F8, floor(timer / 100)) set_timer(timer)
W(0x7F9, floor((timer / 10) % 10))
W(0x7FA, floor(timer % 10))
end end
-- if we've run out of time while the game is paused...
-- that's cheating! unpause.
force_start = ingame_paused and timer == 0
local X = {} -- TODO: cache. local X = {} -- TODO: cache.
for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(sprite_input) do insert(X, v / 256) end
for i, v in ipairs(tile_input) do insert(X, v / 256) end for i, v in ipairs(tile_input) do insert(X, v / 256) end
--error(#X)
if network == nil then if enable_network and get_state() == 'playing' or ingame_paused then
network = make_network(#X, 8); network:reset()
print("parameters:", network.n_param)
end
if enable_network and getstate() == 'playing' then
local outputs = network:forward(X) local outputs = network:forward(X)
local softmaxed = { local softmaxed = {
outputs[nn_z[1]], outputs[nn_z[1]],
@ -553,12 +707,25 @@ while true do
start = argmax2(softmaxed[7]), start = argmax2(softmaxed[7]),
select = argmax2(softmaxed[8]), select = argmax2(softmaxed[8]),
} }
if force_start then
jp = {
up = false,
down = false,
left = false,
right = false,
A = false,
B = false,
start = force_start_old,
select = false,
}
end
joypad.write(1, jp) joypad.write(1, jp)
end end
coins_old = coins coins_old = coins
powerup_old = powerup powerup_old = powerup
status_old = status status_old = status
state_old = getstate() force_start_old = force_start
state_old = get_state()
advance() advance()
end end

206
nn.lua
View file

@ -13,6 +13,7 @@ local cos = math.cos
local sin = math.sin local sin = math.sin
local insert = table.insert local insert = table.insert
local remove = table.remove local remove = table.remove
local open = io.open
local bor = bit.bor local bor = bit.bor
@ -24,6 +25,17 @@ local function contains(t, a)
return false return false
end end
-- product of a list of numbers.
-- accepts either varargs, prod(2, 3, 4), or a single array table,
-- prod({2, 3, 4}); when the first argument is a table, any further
-- arguments are ignored (as before).
-- an empty table yields 1, the multiplicative identity, so an empty
-- shape no longer produces nil and a later arithmetic-on-nil error.
-- a bare prod() or prod(nil) still returns nil, so a missing size is
-- not silently turned into a valid number.
local function prod(x, ...)
    if type(x) == "table" then
        if #x == 0 then return 1 end
        return prod(unpack(x))
    end
    -- capture the varargs once; calling select(i, ...) inside the loop
    -- re-walks the argument list on every iteration.
    local count = select("#", ...)
    local rest = {...}
    local ret = x
    for i = 1, count do
        ret = ret * rest[i]
    end
    return ret
end
local function normal() -- box muller local function normal() -- box muller
return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform()) return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform())
end end
@ -58,11 +70,7 @@ local function copy(t) -- shallow copy
end end
local function allocate(t, out, init) local function allocate(t, out, init)
-- FIXME: this code is fucking disgusting.
out = out or {} out = out or {}
assert(type(out) == "table", type(out))
if type(t) == "number" then
local size = t local size = t
if init ~= nil then if init ~= nil then
return init(zeros(size, out)) return init(zeros(size, out))
@ -70,16 +78,45 @@ local function allocate(t, out, init)
return zeros(size, out) return zeros(size, out)
end end
end end
local topsize = t[1]
t = copy(t) local function levelorder(field, node_in, nodes)
remove(t, 1) -- horribly inefficient.
if #t == 1 then t = t[1] end nodes = nodes or {}
for i = 1, topsize do local q = {node_in}
local res = allocate(t, nil, init) while #q > 0 do
assert(res ~= nil) local node = q[1]
insert(out, res) remove(q, 1)
insert(nodes, node)
for _, child in ipairs(node[field]) do
q[#q+1] = child
end end
return out end
return nodes
end
-- list the nodes that are visited both when walking 'children' links
-- from node_in and when walking 'parents' links from node_out.
-- matches are appended to `nodes` (a new table when omitted) in the
-- level order of the walk from node_in; `nodes` is returned.
local function traverse(node_in, node_out, nodes)
    nodes = nodes or {}
    local from_input = levelorder('children', node_in, {})
    local from_output = levelorder('parents', node_out, {})

    -- set of everything seen on the walk up from node_out.
    local reaches_output = {}
    for _, node in ipairs(from_output) do
        reaches_output[node] = true
    end

    for _, node in ipairs(from_input) do
        if reaches_output[node] then
            insert(nodes, node)
        end
    end
    return nodes
end
local function traverse_all(nodes_in, nodes_out, nodes)
local all_in = {children={}}
local all_out = {parents={}}
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
return traverse(all_in, all_out, nodes or {})
end end
local Weights = Base:extend() local Weights = Base:extend()
@ -95,24 +132,13 @@ function Weights:init(weight_init)
end end
function Weights:allocate(fan_in, fan_out) function Weights:allocate(fan_in, fan_out)
self.size = prod(self.shape)
return allocate(self.size, self, function(t) return allocate(self.size, self, function(t)
--print('initializing weights of size', self.size, 'with fans', fan_in, fan_out) --print('initializing weights of size', self.size, 'with fans', fan_in, fan_out)
return self.weight_init(t, fan_in, fan_out) return self.weight_init(t, fan_in, fan_out)
end) end)
end end
--[[
local w = Weights(init_he_uniform)
w.size = {16, 16}
w:allocate(16, 16)
print(w)
do return end
local w = zeros(16)
for i = 1, #w do w[i] = normal() * 1920 / 2560 end
print(w)
--]]
local counter = {} local counter = {}
function Layer:init(name) function Layer:init(name)
assert(type(name) == "string") assert(type(name) == "string")
@ -152,18 +178,7 @@ function Layer:_new_weights(init)
return w return w
end end
local function prod(x, ...) function Layer:get_size()
if type(x) == "table" then
return prod(unpack(x))
end
local ret = x
for i = 1, select("#", ...) do
ret = ret * select(i, ...)
end
return ret
end
function Layer:getsize()
local size = 0 local size = 0
for i, w in ipairs(self.weights) do size = size + prod(w.size) end for i, w in ipairs(self.weights) do size = size + prod(w.size) end
return size return size
@ -237,8 +252,8 @@ end
function Dense:make_shape(parent) function Dense:make_shape(parent)
self.size_in = parent.size_out self.size_in = parent.size_out
self.coeffs.size = {self.dim, self.size_in} self.coeffs.shape = {self.size_in, self.dim}
self.biases.size = self.dim self.biases.shape = self.dim
end end
function Dense:forward(X) function Dense:forward(X)
@ -246,11 +261,11 @@ function Dense:forward(X)
self.cache = self.cache or zeros(self.size_out) self.cache = self.cache or zeros(self.size_out)
local Y = self.cache local Y = self.cache
for i = 1, #self.coeffs do for i = 1, self.dim do
local res = 0 local res = 0
local c = self.coeffs[i] local c = (i - 1) * #X
for j = 1, #X do for j = 1, #X do
res = res + X[j] * c[j] res = res + X[j] * self.coeffs[c + j]
end end
Y[i] = res + self.biases[i] Y[i] = res + self.biases[i]
end end
@ -281,46 +296,6 @@ function Softmax:forward(X)
return Y return Y
end end
local function levelorder(field, node_in, nodes)
-- horribly inefficient.
nodes = nodes or {}
local q = {node_in}
while #q > 0 do
local node = q[1]
remove(q, 1)
insert(nodes, node)
for _, child in ipairs(node[field]) do
q[#q+1] = child
end
end
return nodes
end
local function traverse(node_in, node_out, nodes)
nodes = nodes or {}
local down = levelorder('children', node_in, {})
local up = levelorder('parents', node_out, {})
local seen = {}
for _, node in ipairs(up) do
seen[node] = bor(seen[node] or 0, 1)
end
for _, node in ipairs(down) do
seen[node] = bor(seen[node] or 0, 2)
if seen[node] == 3 then
insert(nodes, node)
end
end
return nodes
end
local function traverse_all(nodes_in, nodes_out, nodes)
local all_in = {children={}}
local all_out = {parents={}}
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
return traverse(all_in, all_out, nodes or {})
end
function Model:init(nodes_in, nodes_out) function Model:init(nodes_in, nodes_out)
assert(#nodes_in > 0, #nodes_in) assert(#nodes_in > 0, #nodes_in)
assert(#nodes_out > 0, #nodes_out) assert(#nodes_out > 0, #nodes_out)
@ -338,7 +313,7 @@ function Model:reset()
self.n_param = 0 self.n_param = 0
for _, node in ipairs(self.nodes) do for _, node in ipairs(self.nodes) do
node:init_weights() node:init_weights()
self.n_param = self.n_param + node:getsize() self.n_param = self.n_param + node:get_size()
end end
end end
@ -359,10 +334,75 @@ function Model:forward(X)
return outputs return outputs
end end
function Model:collect()
    -- gather every weight in the graph into one new flat array, visiting
    -- self.nodes in order and each node's weight arrays in order.
    -- Model:distribute() performs the inverse.
    assert(self.n_param >= 0, self.n_param)

    local flat = zeros(self.n_param)
    local offset = 0 -- number of slots of `flat` already filled.
    for _, node in ipairs(self.nodes) do
        for _, weights in ipairs(node.weights) do
            for k, value in ipairs(weights) do
                flat[offset + k] = value
            end
            offset = offset + #weights
        end
    end
    return flat
end
function Model:distribute(W)
    -- write the flat array W back into the graph's weight arrays, using
    -- the same node/weight ordering as Model:collect() (its inverse).
    local offset = 0 -- number of entries of W already consumed.
    for _, node in ipairs(self.nodes) do
        for _, weights in ipairs(node.weights) do
            for k in ipairs(weights) do
                weights[k] = W[offset + k]
            end
            offset = offset + #weights
        end
    end
end
function Model:save(fn)
    -- write every parameter of the network to a text file, one number
    -- per line, in the order produced by Model:collect().
    -- fn: path of the file to write; defaults to 'network.txt'.
    -- raises an error if the file cannot be opened for writing.
    fn = fn or 'network.txt'
    local f, err = open(fn, 'w')
    if f == nil then
        error("Failed to save network to file "..fn..": "..tostring(err))
    end
    local W = self:collect()
    for _, v in ipairs(W) do
        -- writing the number directly would format it with Lua's default
        -- "%.14g", which drops digits; 17 significant digits let a double
        -- survive the trip through Model:load() unchanged.
        f:write(("%.17g\n"):format(v))
    end
    f:close()
end
function Model:load(fn)
    -- read parameters from a text file written by Model:save() (one
    -- number per line) and load them into the network via distribute().
    -- fn: path of the file to read; defaults to 'network.txt'.
    -- raises an error if the file cannot be opened, if a line is not a
    -- number, or if the number of lines differs from self.n_param.
    fn = fn or 'network.txt'
    local f, err = open(fn, 'r')
    if f == nil then
        error("Failed to load network from file "..fn..": "..tostring(err))
    end

    local W = zeros(self.n_param)
    local count = 0
    local failure
    for line in f:lines() do
        count = count + 1
        local n = tonumber(line)
        if n == nil then
            -- remember the problem and stop; the error is raised only
            -- after the handle has been closed.
            failure = "Failed reading line "..tostring(count).." of file "..fn
            break
        end
        W[count] = n
    end
    f:close()

    -- a file saved from a differently-shaped network must not be loaded
    -- partially (short file) or truncated (long file) without notice.
    if failure == nil and count ~= self.n_param then
        failure = "Expected "..tostring(self.n_param)..
            " parameters in file "..fn..", found "..tostring(count)
    end
    if failure ~= nil then
        error(failure)
    end

    self:distribute(W)
end
return { return {
uniform = uniform, uniform = uniform,
normal = normal, normal = normal,
copy = copy,
zeros = zeros, zeros = zeros,
init_zeros = init_zeros, init_zeros = init_zeros,
init_he_uniform = init_he_uniform, init_he_uniform = init_he_uniform,