-- Patch metadata (from the original git format-patch header):
--   From 29f3c278eac06a39ccb8f957f193e8f06ef4d4ff Mon Sep 17 00:00:00 2001
--   From: Connor
--   Date: Wed, 28 Jun 2017 02:33:18 -0700
--   Subject: [PATCH]
--
--   main.lua | 561 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
--   nn.lua   | 377 +++++++++++++++++++++++++++++++++++++
--   2 files changed, 938 insertions(+)
--   create mode 100644 main.lua
--   create mode 100644 nn.lua
--
--   diff --git a/main.lua b/main.lua
--   new file mode 100644
--   index 0000000..adf98f2
--   (old: /dev/null, new: b/main.lua, hunk: -0,0 +1,561)

-- be strict about globals.
-- Any read or write of a global that has not been explicitly declared raises
-- an error (reported at the caller's position) instead of silently creating
-- or returning nil.

local mt = getmetatable(_G)
if mt == nil then
    mt = {}
    setmetatable(_G, mt)
end

function mt.__newindex(t, n, v)
    error("cannot assign undeclared global '" .. tostring(n) .. "'", 2)
end

function mt.__index(t, n)
    error("cannot use undeclared global '" .. tostring(n) .. "'", 2)
end

-- Deliberately export every key/value of `t` as a global.
-- rawset bypasses the __newindex guard installed above.
local function globalize(t)
    for k, v in pairs(t) do
        rawset(_G, k, v)
    end
end

-- localize some stuff.

local ipairs = ipairs
local pairs = pairs
local select = select
local abs = math.abs
local floor = math.floor
local ceil = math.ceil
local min = math.min
local max = math.max
local exp = math.exp
local sqrt = math.sqrt
local random = math.random
local insert = table.insert
local remove = table.remove
-- `table.unpack` is checked first so the strict __index guard is never hit
-- on interpreters that lack the bare `unpack` global.
local unpack = table.unpack or unpack
local R = memory.readbyteunsigned
local S = memory.readbyte --signed
local W = memory.writebyte

local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift
local rshift = bit.rshift
local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror

-- Restrict x to the closed interval [l, u].
local function clamp(x, l, u) return min(max(x, l), u) end

-- Return the 1-based position of the largest argument.
-- Ties keep the earliest position; with no arguments the result is 0.
local function argmax(...)
    local max_i = 0
    -- start from negative infinity so that arbitrarily small inputs still win
    -- (the previous finite sentinel made very negative inputs return 0).
    local max_v = -math.huge
    for i = 1, select("#", ...) do
        local v = select(i, ...)
        if v > max_v then
            max_i = i
            max_v = v
        end
    end
    return max_i
end

-- Two-way argmax as a boolean: true when the first entry of `t` is strictly
-- greater than the second.
local function argmax2(t)
    return t[1] > t[2]
end

-- game-agnostic stuff (i.e.
-- the network itself)

package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")

local network
local nn_x
local nn_y
local nn_z

-- Build the policy network: one shared Dense+Relu trunk over the whole input
-- vector, then one independent Dense(2)+Softmax head per button.
-- Side effect: fills the file-level nn_x / nn_y / nn_z handles, which the main
-- loop uses to look up each head's output.
-- input_size: length of the input vector.
-- buttons: number of two-way output heads to create.
-- Returns the nn.Model joining the input node to all heads.
local function make_network(input_size, buttons)
    nn_x = nn.Input(input_size)
    nn_y = nn_x
    nn_z = {}
    nn_y = nn_y:feed(nn.Dense(input_size))
    nn_y = nn_y:feed(nn.Relu())
    --nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
    --nn_y = nn_y:feed(nn.Relu())
    --nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
    --nn_y = nn_y:feed(nn.Relu())
    for i = 1, buttons do
        nn_z[i] = nn_y
        nn_z[i] = nn_z[i]:feed(nn.Dense(2))
        nn_z[i] = nn_z[i]:feed(nn.Softmax())
    end

    return nn.Model({nn_x}, nn_z)
end

-- and here we go with the game stuff.

local enable_overlay = false
local enable_network = true

--[[
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
--]]

-- Flat list of (x, y) pixel offsets, one pair per rotation step (32 steps).
-- The trailing comments give the step index in hexadecimal.
local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
     0, -40, -- 0
     6, -38,
    15, -37,
    22, -32,
    28, -28,
    32, -22,
    37, -14,
    39,  -6,
    40,   0, -- 8
    38,   7,
    37,  15,
    33,  23,
    27,  29,
    22,  33,
    14,  37,
     6,  39,
     0,  41, -- 10
    -7,  40,
   -16,  38,
   -22,  34,
   -28,  28,
   -34,  23,
   -38,  16,
   -40,   8,
   -40,  -0, -- 18
   -40,  -6,
   -38, -14,
   -34, -22,
   -28, -28,
   -22, -32,
   -16, -36,
    -8, -38,
}

local startsave = savestate.create(1)

local poketime = false

local sprite_input = {}
local tile_input = {}

-- Append one sprite as an (x, y, type) triple to sprite_input.
-- Positions outside the visible area are recorded as (0, 0, 0) so the input
-- vector keeps a fixed length. When the overlay is enabled, also draw a marker
-- for every sprite whose type is nonzero.
local function shitsprite(x, y, t)
    if x < 0 or x >= 256 or y < 0 or y > 224 then
        sprite_input[#sprite_input+1] = 0
        sprite_input[#sprite_input+1] = 0
        sprite_input[#sprite_input+1] = 0
    else
        sprite_input[#sprite_input+1] = x
        sprite_input[#sprite_input+1] = y
        sprite_input[#sprite_input+1] = t
    end
    if t == 0 then return end
    if enable_overlay then
        gui.box(x-4, y-4, x+4, y+4)
        --gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
        gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
        --gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F')
    end
end

-- Append one tile id to tile_input; when the overlay is enabled, outline and
-- label every nonzero tile at screen position (x, y).
local function shittile(x, y, t)
    tile_input[#tile_input+1] = t
    if t == 0 then return end
    if enable_overlay then
        gui.box(x-8, y-8, x+8, y+8)
        gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
    end
end

-- Read object slot `i`'s position and convert it to screen coordinates.
-- x_addr / y_addr: base addresses of the per-slot position bytes.
-- pageloc_addr (optional): base address of the per-slot page byte; when given,
--   the horizontal page difference against 0x71A is folded in (256 px/page).
-- hipos_addr (optional): base address of a signed per-slot byte; a value of 1
--   leaves y unchanged, other values shift y by multiples of 256.
-- Returns sx, sy.
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
    -- only the left-edge page (0x71A) and left-edge x (0x71C) are needed; the
    -- right-edge counterparts at 0x71B / 0x71D were read but never used.
    local spl_l = R(0x71A)
    local sx_l = R(0x71C)

    local x = R(x_addr + i)
    local y = R(y_addr + i)
    local sx, sy = x, y
    if pageloc_addr ~= nil then
        local page = R(pageloc_addr + i)
        sx = sx - sx_l - (spl_l - page) * 256
    else
        sx = sx - sx_l
    end
    if hipos_addr ~= nil then
        local hipos = S(hipos_addr + i)
        sy = sy + (hipos - 1) * 256
    end

    return sx, sy
end

local reward
local powerup_old
local status_old
local coins_old

local once = false
local reset = true

emu.poweron()
emu.unpause()
emu.speedmode("normal")

local function opermode() return R(0x770) end
local function paused() return band(R(0x776), 1) end
local function subroutine() return R(0xE) end

-- Classify the current game state from a handful of RAM bytes.
-- The checks are ordered: the first matching condition wins, so earlier
-- entries take precedence over later ones. Returns a state name string.
local function getstate()
    if R(0xE) == 0xFF then return 'power' end
    if R(0x774) > 0 then return 'lagging' end
    if R(0x7A2) > 0 then return 'waiting_demo' end
    if R(0x717) > 0 then return 'playing_demo' end
--  if R(0x770) == 0xFF then return 'power' end
    if paused() ~= 0 then return 'paused' end
    if R(0xE) == 0 then return 'world_screen' end
--  if R(0x712) == 1 then return 'deadmusic' end
    if R(0x7CA) == 0x94 then return 'dead' end
    if R(0xE) == 4 then return 'win_flagpole' end
    if R(0xE) == 5 then return 'win_walking' end
    if R(0xE) == 6 then return 'lose' end
--  if R(0x770) == 0 then return 'not_playing' end
    if R(0x770) == 2 then return 'win_castle' end
    if R(0x772) == 2 then return 'no_control' end
    if R(0x772) == 3 then return 'playing' end
    if R(0x770) == 1 then return 'loading' end
    if R(0x770) == 3 then return 'lose' end
    return 'unknown'
end

local ok_routines = {
    [0x4] = true, -- sliding down flagpole
    [0x5] = true, -- end of level auto-walk
    [0x7] = true, -- start of level auto-walk
    [0x8] = true, -- normal (in control)
    [0x9] = true, -- acquiring mushroom
    [0xA] = true, -- losing big mario
    [0xB] = true, -- uhh
    [0xC] = true, -- acquiring fireflower
}

-- States in which the script mashes start and schedules a trial reset.
local fuckstates = {
    power = true,
    waiting_demo = true,
    playing_demo = true,
    unknown = true,
    lose = true,
    paused = true,
}

-- Advance the emulator by one frame, then keep advancing while the emulator
-- reports lag or while the byte at 0x774 is nonzero.
local function advance()
    emu.frameadvance()
    while emu.lagged() do emu.frameadvance() end -- skip lag frames.
    while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
end

local state_old = ''
-- debugging aid, disabled: flip `false` to `true` to log every state change.
while false do
    local state = getstate()
    if state ~= state_old then
        print(emu.framecount(), state)
        state_old = state
    end
    advance()
end

-- Record the six enemy-slot objects as sprite inputs, nudging the anchor point
-- per object type. Rotating fire bars are projected from their rotation step
-- through rotation_offsets.
local function handle_enemies()
    -- enemies, flagpole
    for i = 0, 5 do
        local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
        x, y = x + 8, y + 16
        local tid = R(0x16 + i)
        local flags = R(0xF + i)
        --local offscr = R(0x3D8 + i)
        local invisible = tid < 0x10 and flags == 0
        if tid == 0x30 then y = y - 8 end -- flagpole flag
        if tid == 0x31 then y = y - 8 end -- castle flag
        if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks
        if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms
        if tid == 0x2D then x, y = x, y end -- bowser (TODO: determine head or body)
        if tid == 0x15 then x, y = x, y - 12 end -- bowser fire
        if tid == 0x32 then x, y = x, y - 8 end -- spring
        -- tid == 0x35 -- toad
        if tid == 0x1D or tid == 0x1B then -- rotating fire bars
            x, y = x - 4, y - 12
            -- this is a mess... gotta find out its rotation and then project.
            -- TODO: handle long fire bars too
            local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
            -- debug text belongs to the overlay; it used to be drawn even
            -- with the overlay switched off.
            if enable_overlay then
                gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
            end
            -- the table only covers steps 0..31; fall back to no offset rather
            -- than crashing on arithmetic with nil for an unexpected value.
            local x_off = rotation_offsets[rot*2+1] or 0
            local y_off = rotation_offsets[rot*2+2] or 0
            x, y = x + x_off, y + y_off
        end
        if invisible then
            shitsprite(0, 0, 0)
        else
            shitsprite(x, y, tid + 1)
        end
    end
end

-- Record the two fireball slots; active ones use the fixed type id 257.
local function handle_fireballs()
    -- fireballs
    for i = 0, 1 do
        local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC)
        x, y = x + 4, y + 4
        local state = R(0x24 + i)
        local invisible = state == 0
        if invisible then
            shitsprite(0, 0, 0)
        else
            shitsprite(x, y, 257)
        end
    end
end

-- Record the four block slots; active ones use the fixed type id 258.
-- (Currently not called from the main loop.)
local function handle_blocks()
    for i = 0, 3 do
        local x, y = getxy(i, 0x8F, 0xD7, 0x76, 0xBE)
        x, y = x + 8, y + 8
        local state = R(0x26 + i)
        local invisible = state == 0
        if invisible then
            shitsprite(0, 0, 0)
        else
            shitsprite(x, y, 258)
        end
    end
end

-- Record the nine misc-object slots. Only states >= 0x30 are kept; everything
-- else is emitted as an empty sprite.
local function handle_hammers()
    -- hammers, coins, score bonus text...
    for i = 0, 8 do
        local x, y = getxy(i, 0x93, 0xDB, 0x7A, 0xC2)
        x, y = x + 8, y + 8
        local state = R(0x2A + i)
        -- skip coin effect states. not interactable; we don't care!
        if state ~= 0
        and state >= 0x30
        then
            shitsprite(x, y, state + 1)
        else
            shitsprite(0, 0, 0)
        end
    end
end

-- Record the single slot read via state byte 0x33.
local function handle_misc()
    for i = 0, 0 do
        local x, y = getxy(i, 0x9C, 0xE4, 0x83, 0xCB)
        x, y = x + 8, y + 8
        local state = R(0x33 + i)
        if state ~= 0 then
            shitsprite(x, y, state + 1)
        else
            shitsprite(0, 0, 0)
        end
    end
end

-- Record the sub-tile scroll remainder followed by a 17x13 window of tile ids
-- starting at the current scroll column. Columns wrap modulo 32 across two
-- 16-column buffers based at 0x500 and 0x5D0.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    tile_input[#tile_input+1] = tile_scroll_remainder
    for y = 0, 12 do
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local t
            if col < 16 then
                t = R(0x500 + y * 16 + (col % 16))
            else
                t = R(0x5D0 + y * 16 + (col % 16))
            end
            local sx = x * 16 + 8 - tile_scroll_remainder
            local sy = y * 16 + 40
            shittile(sx, sy, t)
        end
    end
end

-- main loop: one iteration per (non-lag) frame.
while true do
    while fuckstates[getstate()] do
        --gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F')
        -- mash the start button until we have control.
        -- TODO: learn this too.
        --local jp = joypad.read(1)
        local jp = {
            up = false,
            down = false,
            left = false,
            right = false,
            A = false,
            B = false,
            select = false,
            start = emu.framecount() % 2 == 1,
        }
        joypad.write(1, jp)

        reset = true

        gui.text(4, 12, getstate(), '#FFFFFF', '#0000003F')
        advance()

        -- bit of a hack:
        while getstate() == "loading" do advance() end
        state_old = getstate()
    end

    if reset then
        print("resetting in state:", getstate())

        -- the first pass saves the starting point; later passes rewind to it.
        if once then
            savestate.load(startsave)
            print("end of trial reward:", reward)
            print()
        else
            savestate.save(startsave)
        end
        once = true

        -- bit of a hack:
        if getstate() == 'loading' then advance() end
        reward = 0
        powerup_old = R(0x754)
        status_old = R(0x756)
        coins_old = R(0x7ED) * 10 + R(0x7EE)

        -- set lives to 0. you only got one shot!
        -- unless you get a 1-up, in which case, please continue!
        W(0x75A, 0)

        emu.frameadvance() -- prevents emulator from quirking up.

        reset = false
        if network ~= nil then network:reset() end -- FIXME: hack
    end

    if not enable_network then
        -- infinite time cheat. super handy for testing.
        if R(0xE) == 8 then
            W(0x7F8, 9)
            W(0x7F9, 9)
            W(0x7FA, 10)
            poketime = true
        elseif poketime then
            poketime = false
            W(0x7F8, 0)
            W(0x7F9, 0)
            W(0x7FA, 1)
        end

        -- infinite lives.
        W(0x75A, 1)
    end

    -- empty input lists without creating a new table.
    for k, v in pairs(sprite_input) do sprite_input[k] = nil end
    for k, v in pairs(tile_input) do tile_input[k] = nil end

    -- player
    -- TODO: add check if mario is playable.
    local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = R(0x754)
    local status = R(0x756)
    shitsprite(x + 8, y + 24, -powerup - 1)

    handle_enemies()
    handle_fireballs()
    -- blocks being hit. not interactable; we don't care!
    --handle_blocks()
    handle_hammers()
    handle_misc()
    handle_tiles()

    -- NOTE: coins_delta and powerup_delta are computed but do not currently
    -- contribute to reward_delta.
    local coins = R(0x7ED) * 10 + R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)
    local screen_scroll_delta = R(0x775)
    local flagpole_bonus = subroutine() == 4 and 1 or 0
    local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus

    reward = reward + reward_delta

    --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
    --gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
    gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')

    if getstate() == 'dead' and state_old ~= 'dead' then
        -- read the lives byte once (the reader takes a single address).
        local lives = R(0x75A)
        print("dead. lives remaining:", lives)
        if lives == 0 then reset = true end
    end
    if getstate() == 'lose' then
        print("ran out of lives.")
        reset = true
    end

    -- every few frames mario stands still, forcibly decrease the timer.
    local timer_loser = 1/5
    if random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
        -- the timer is stored as three decimal digits, one per byte.
        local timer = R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
        timer = clamp(timer - 1, 0, 400)
        W(0x7F8, floor(timer / 100))
        W(0x7F9, floor((timer / 10) % 10))
        W(0x7FA, floor(timer % 10))
    end

    -- flatten sprites then tiles into one input vector, scaled by 1/256.
    local X = {} -- TODO: cache.
    for i, v in ipairs(sprite_input) do insert(X, v / 256) end
    for i, v in ipairs(tile_input) do insert(X, v / 256) end

    -- the network is built lazily, once the input length is known.
    if network == nil then
        network = make_network(#X, 8); network:reset()
        print("parameters:", network.n_param)
    end

    if enable_network and getstate() == 'playing' then
        local outputs = network:forward(X)
        local softmaxed = {
            outputs[nn_z[1]],
            outputs[nn_z[2]],
            outputs[nn_z[3]],
            outputs[nn_z[4]],
            outputs[nn_z[5]],
            outputs[nn_z[6]],
            outputs[nn_z[7]],
            outputs[nn_z[8]],
        }
        -- a button is pressed when its head's first output beats the second.
        local jp = {
            up = argmax2(softmaxed[1]),
            down = argmax2(softmaxed[2]),
            left = argmax2(softmaxed[3]),
            right = argmax2(softmaxed[4]),
            A = argmax2(softmaxed[5]),
            B = argmax2(softmaxed[6]),
            start = argmax2(softmaxed[7]),
            select = argmax2(softmaxed[8]),
        }
        joypad.write(1, jp)
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    state_old = getstate()
    advance()
end

-- Patch metadata for the second file:
--   diff --git a/nn.lua b/nn.lua
--   new file mode 100644
--   index 0000000..26c5a49
--   (old: /dev/null, new: b/nn.lua, hunk: -0,0 +1,377)

local tostring = tostring
local ipairs = ipairs
local pairs = pairs
local uniform = math.random
local sqrt = math.sqrt
local log = math.log
local pi = math.pi
local exp = math.exp
local min = math.min
local max = math.max
local cos = math.cos
local sin = math.sin
local insert = table.insert
local remove = table.remove

local bor = bit.bor

local Base = require("Base")

-- True when some value in table `t` equals `a` (linear scan over all keys).
local function contains(t, a)
    assert(type(t) == "table")
    for k, v in pairs(t) do if v == a then return true end end
    return false
end

-- Draw one Box-Muller style sample from two uniform draws, halved.
-- The 1e-8 terms keep log() away from zero and the sqrt() argument
-- non-negative.
local function normal()
    return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform()) / 2
end

-- Set out[1..n] to 0 and return it; a fresh table is created when `out` is
-- omitted.
local function zeros(n, out)
    out = out or {}
    for i = 1, n do out[i] = 0 end
    return out
end

-- Initializers share the signature (t, fan_in, fan_out): they overwrite the
-- existing entries t[1..#t] in place and return t.

local function init_zeros(t, fan_in, fan_out)
    for i = 1, #t do t[i] = 0 end
    return t
end

local function init_he_uniform(t, fan_in, fan_out)
    local s = sqrt(6 / fan_in)
    for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
    return t
end

local function init_he_normal(t, fan_in, fan_out)
    local s = sqrt(2 / fan_in)
    for i = 1, #t do t[i] = normal() * s end
    return t
end

local function copy(t) -- shallow copy
    local new_t = {}
    for k, v in pairs(t) do new_t[k] = v end
    return new_t
end

-- Allocate a (possibly nested) array described by `t`.
-- t: either a number (flat array of that length) or a list of dimension sizes
--    (outermost first), producing nested tables.
-- out: optional table to fill; nested rows are appended to it with insert, so
--    callers reusing a table must clear its array part first.
-- init: optional function applied to each innermost zero-filled array.
-- Returns the filled table.
local function allocate(t, out, init)
    -- FIXME: this code is ugly and wants a rewrite.

    out = out or {}
    assert(type(out) == "table", type(out))
    if type(t) == "number" then
        local size = t
        if init ~= nil then
            return init(zeros(size, out))
        else
            return zeros(size, out)
        end
    end
    local topsize = t[1]
    t = copy(t)
    remove(t, 1)
    -- a single remaining dimension collapses to the flat (number) case.
    if #t == 1 then t = t[1] end
    for i = 1, topsize do
        local res = allocate(t, nil, init)
        assert(res ~= nil)
        insert(out, res)
    end
    return out
end

local Weights = Base:extend()
local Layer = Base:extend()
local Model = Base:extend()
local Input = Layer:extend()
local Relu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()

-- weight_init: initializer with the (t, fan_in, fan_out) signature.
function Weights:init(weight_init)
    self.weight_init = weight_init
end

-- Fill this Weights object in place according to self.size, running the
-- stored initializer on each innermost array.
function Weights:allocate(fan_in, fan_out)
    return allocate(self.size, self, function(t)
        --print('initializing weights of size', self.size, 'with fans', fan_in, fan_out)
        return self.weight_init(t, fan_in, fan_out)
    end)
end

-- disabled debug snippet:
-- local w = Weights(init_he_uniform)
-- w.size = {16, 16}
-- w:allocate(16, 16)
-- print(w)
-- do return end
--
-- local w = zeros(16)
-- for i = 1, #w do w[i] = normal() * 1920 / 2560 end
-- print(w)
-- (end of disabled debug snippet)

-- per-name instance counter, used to give every layer a unique display name.
local counter = {}

-- name: layer kind; the instance is labelled "<name>[<n>]".
function Layer:init(name)
    assert(type(name) == "string")
    counter[name] = (counter[name] or 0) + 1
    self.name = name.."["..tostring(counter[name]).."]"
    self.parents = {}
    self.children = {}
    self.weights = {}
    --self.size_in = nil
    --self.size_out = nil
end

-- Default shape inference: inherit the parent's output size, and keep the
-- output size equal to the input size unless already set.
function Layer:make_shape(parent)
    if self.size_in == nil then self.size_in = parent.size_out end
    if self.size_out == nil then self.size_out = self.size_in end
end

-- Connect `child` after this layer and return `child`, so calls can be
-- chained. This layer's output size must already be known.
function Layer:feed(child)
    assert(self.size_out ~= nil)
    child:make_shape(self)
    insert(self.children, child)
    insert(child.parents, self)
    return child
end

function Layer:forward()
    error("Unimplemented.")
end

-- By default the deterministic pass is the ordinary forward pass.
function Layer:forward_deterministic(...)
    return self:forward(...)
end

-- Create a Weights object with the given initializer and register it on this
-- layer.
function Layer:_new_weights(init)
    local w = Weights(init)
    insert(self.weights, w)
    return w
end

-- same fallback as main.lua: prefer table.unpack where it exists.
local unpack = table.unpack or unpack

-- Product of all arguments; a single table argument is multiplied out
-- element-wise.
local function prod(x, ...)
    if type(x) == "table" then
        return prod(unpack(x))
    end
    local ret = x
    for i = 1, select("#", ...) do
        ret = ret * select(i, ...)
    end
    return ret
end

-- Total number of scalar parameters across this layer's weights.
function Layer:getsize()
    local size = 0
    for i, w in ipairs(self.weights) do size = size + prod(w.size) end
    return size
end

-- (Re)allocate and initialize every weight set of this layer.
function Layer:init_weights()
    for i, w in ipairs(self.weights) do
        --print("allocating weights", i, "of", self.name)
        -- clear the array part first: allocation appends rows with insert.
        for j, v in ipairs(w) do w[j] = nil end -- FIXME: HACK
        w:allocate(self.size_in, self.size_out)
    end
end

-- Run this layer on exactly one incoming edge.
function Layer:_propagate(edges, deterministic)
    assert(#edges == 1, #edges) -- override this if you need multiple parents.
    if deterministic then
        return self:forward_deterministic(edges[1])
    else
        return self:forward(edges[1])
    end
end

-- Gather the already-computed outputs of this layer's parents from `values`
-- (keyed by layer) and run the layer on them.
function Layer:propagate(values, deterministic)
    local edges = {}
    for i, parent in ipairs(self.parents) do
        if values[parent] ~= nil then
            local X = values[parent]
            insert(edges, X)
        end
    end
    assert(#edges > 0, #edges)
    local Y = self:_propagate(edges, deterministic)
    return Y
end

function Input:init(size)
    Layer.init(self, "Input")
    assert(type(size) == 'number')
    self.size_in = size
    self.size_out = size
end

-- Identity: validates the length and passes X through unchanged.
function Input:forward(X)
    assert(#X == self.size_in)
    return X
end

function Relu:init()
    Layer.init(self, "Relu")
end

-- Element-wise max(x, 0). The returned table is a per-layer cache that is
-- overwritten on the next call.
function Relu:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache

    for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end

    assert(#Y == self.size_out)
    return Y
end

-- dim: number of output units.
function Dense:init(dim)
    Layer.init(self, "Dense")
    assert(type(dim) == "number")
    self.dim = dim
    self.size_out = dim
    self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
    self.biases = self:_new_weights(init_zeros)
end

-- The coefficient matrix is dim rows by size_in columns; biases are a flat
-- array of length dim.
function Dense:make_shape(parent)
    self.size_in = parent.size_out
    self.coeffs.size = {self.dim, self.size_in}
    self.biases.size = self.dim
end

-- Y[i] = sum_j X[j] * coeffs[i][j] + biases[i]. The returned table is a
-- per-layer cache that is overwritten on the next call.
function Dense:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache

    for i = 1, #self.coeffs do
        local res = 0
        local c = self.coeffs[i]
        for j = 1, #X do
            res = res + X[j] * c[j]
        end
        Y[i] = res + self.biases[i]
    end

    assert(#Y == self.size_out)
    return Y
end

function Softmax:init()
    Layer.init(self, "Softmax")
end

-- Numerically stable softmax. The returned table is a per-layer cache that is
-- overwritten on the next call.
function Softmax:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache

    -- shift by the true maximum so the largest exponent is exactly 0. Seeding
    -- the maximum with 0 (as before) let strongly negative inputs underflow
    -- every exp() to 0, which made the denominator 0 and the outputs NaN.
    local alpha = X[1]
    for i = 2, #X do alpha = max(alpha, X[i]) end

    -- the numerators are stored straight into the output cache, so no
    -- temporary table is allocated per call.
    local den = 0
    for i = 1, #X do
        local e = exp(X[i] - alpha)
        Y[i] = e
        den = den + e
    end
    for i = 1, #X do Y[i] = Y[i] / den end

    assert(#Y == self.size_out)
    return Y
end

-- Breadth-first walk from node_in following node[field] ('children' or
-- 'parents'), appending every visited node (node_in included) to `nodes`.
-- There is no visited set: a node reachable along several paths is appended
-- once per path.
local function levelorder(field, node_in, nodes)
    -- horribly inefficient.
    nodes = nodes or {}
    local q = {node_in}
    while #q > 0 do
        local node = q[1]
        remove(q, 1)
        insert(nodes, node)
        for _, child in ipairs(node[field]) do
            q[#q+1] = child
        end
    end
    return nodes
end

-- Collect, in downward breadth-first order, the nodes that are reachable both
-- downward from node_in and upward from node_out.
local function traverse(node_in, node_out, nodes)
    nodes = nodes or {}
    local down = levelorder('children', node_in, {})
    local up = levelorder('parents', node_out, {})
    -- bit 1: reachable upward from the output; bit 2: reachable downward from
    -- the input. A value of 3 means the node lies on an input-to-output path.
    local seen = {}
    for _, node in ipairs(up) do
        seen[node] = bor(seen[node] or 0, 1)
    end
    for _, node in ipairs(down) do
        seen[node] = bor(seen[node] or 0, 2)
        if seen[node] == 3 then
            insert(nodes, node)
        end
    end
    return nodes
end

-- Multi-input / multi-output variant of traverse: temporary super-source and
-- super-sink tables stand in for the input and output sets. Neither is ever
-- marked from both directions, so they do not appear in the result.
local function traverse_all(nodes_in, nodes_out, nodes)
    local all_in = {children={}}
    local all_out = {parents={}}
    for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
    for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
    return traverse(all_in, all_out, nodes or {})
end

-- nodes_in / nodes_out: non-empty lists of input and output layers.
function Model:init(nodes_in, nodes_out)
    assert(#nodes_in > 0, #nodes_in)
    assert(#nodes_out > 0, #nodes_out)
    --if #nodes_in == 0 and type(nodes_in) == "table" then nodes_in = {nodes_in} end
    --if #nodes_out == 0 and type(nodes_out) == "table" then nodes_out = {nodes_out} end

    self.nodes_in = nodes_in
    self.nodes_out = nodes_out

    -- find all the used (inbetween) nodes in the graph.
    self.nodes = traverse_all(self.nodes_in, self.nodes_out)
end

-- Re-initialize every layer's weights and recount the parameters into
-- self.n_param.
function Model:reset()
    self.n_param = 0
    for _, node in ipairs(self.nodes) do
        node:init_weights()
        self.n_param = self.n_param + node:getsize()
    end
end

-- Evaluate the graph on input vector X.
-- Returns a table mapping each output layer to its output array. The arrays
-- are the layers' own caches, so they are overwritten by the next call.
function Model:forward(X)
    local values = {}
    local outputs = {}
    for i, node in ipairs(self.nodes) do
        --print(i, node.name)
        if contains(self.nodes_in, node) then
            values[node] = node:_propagate({X})
        else
            values[node] = node:propagate(values)
        end
        if contains(self.nodes_out, node) then
            outputs[node] = values[node]
        end
    end
    return outputs
end

return {
    uniform = uniform,
    normal = normal,

    zeros = zeros,
    init_zeros = init_zeros,
    init_he_uniform = init_he_uniform,
    init_he_normal = init_he_normal,

    Weights = Weights,
    Layer = Layer,
    Model = Model,
    Input = Input,
    Relu = Relu,
    Dense = Dense,
    Softmax = Softmax,
}