-- Small feed-forward neural network library: weight initializers, layer
-- types (Input, Relu, Dense, Softmax) that are wired into a DAG with
-- Layer:feed, and a Model that runs the graph and saves/loads its weights.
-- Targets Lua 5.1 / LuaJIT (uses the `bit` library and the global `unpack`).

local print = print
local tostring = tostring
local ipairs = ipairs
local pairs = pairs
local uniform = math.random
local sqrt = math.sqrt
local log = math.log
local pi = math.pi
local exp = math.exp
local min = math.min
local max = math.max
local cos = math.cos
local sin = math.sin
local huge = math.huge
local format = string.format
local insert = table.insert
local remove = table.remove
local open = io.open
local bor = bit.bor

-- NOTE(review): Base is project-local; it presumably provides :extend() and a
-- __call constructor that invokes :init(...) -- confirm against Base.lua.
local Base = require("Base")

-- Returns true if some value stored in table `t` equals `a` (compared with ==).
local function contains(t, a)
    assert(type(t) == "table")
    for _, v in pairs(t) do
        if v == a then return true end
    end
    return false
end

-- Multiplies all arguments together. If the first argument is a table, its
-- elements are used instead, so prod({2, 3}) == prod(2, 3) == 6.
local function prod(x, ...)
    if type(x) == "table" then return prod(unpack(x)) end
    local ret = x
    for i = 1, select("#", ...) do
        ret = ret * select(i, ...)
    end
    return ret
end

-- One sample from the standard normal distribution (Box-Muller transform).
-- The 1e-8 offset keeps log() away from log(0). When uniform() is within
-- ~5e-9 of 1 that offset makes the radicand slightly negative, so it is
-- clamped at 0 to avoid returning NaN; all other draws are unchanged.
local function normal()
    local radius = sqrt(max(0, -2 * log(uniform() + 1e-8) + 1e-8))
    return radius * cos(2 * pi * uniform())
end

-- Writes 0 into out[1..n] (creating `out` when nil) and returns it.
-- Entries past index n are left untouched.
local function zeros(n, out)
    out = out or {}
    for i = 1, n do out[i] = 0 end
    return out
end

-- Initializer: sets every element of `t` to 0. The fans are unused; the
-- parameters exist so all initializers share one signature.
local function init_zeros(t, fan_in, fan_out)
    for i = 1, #t do t[i] = 0 end
    return t
end

-- Initializer: He uniform, t[i] ~ U(-s, s) with s = sqrt(6 / fan_in).
local function init_he_uniform(t, fan_in, fan_out)
    local s = sqrt(6 / fan_in)
    for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
    return t
end

-- Initializer: He normal, t[i] = N(0, 1) * sqrt(2 / fan_in).
local function init_he_normal(t, fan_in, fan_out)
    local s = sqrt(2 / fan_in)
    for i = 1, #t do t[i] = normal() * s end
    return t
end

-- Shallow copy: a new table holding the same keys and values as `t`.
local function copy(t)
    local new_t = {}
    for k, v in pairs(t) do new_t[k] = v end
    return new_t
end

-- Zero-fills out[1..size] (creating `out` when nil). If `init` is given, it
-- is applied to the zeroed table and its result is returned.
local function allocate(size, out, init)
    out = zeros(size, out or {})
    if init ~= nil then return init(out) end
    return out
end

-- Returns the set (node -> true) of every node reachable from `node_in` by
-- repeatedly following `node[field]` (e.g. 'children' or 'parents'),
-- including node_in itself. Each node is visited exactly once, even when it
-- can be reached by several paths.
local function reachable(field, node_in)
    local seen = {[node_in] = true}
    local stack = {node_in}
    while #stack > 0 do
        local node = stack[#stack]
        stack[#stack] = nil
        for _, next_node in ipairs(node[field]) do
            if not seen[next_node] then
                seen[next_node] = true
                stack[#stack + 1] = next_node
            end
        end
    end
    return seen
end

-- Appends to `nodes` (or a new table) every node that is reachable downward
-- (via .children) from `node_in` AND upward (via .parents) from `node_out`,
-- i.e. the nodes lying between the two. Each node appears once.
--
-- The result is in topological order: a node is appended only after all of
-- its parents that are also in the result. This is Kahn's algorithm with a
-- FIFO queue, so chains and trees come out in plain breadth-first order.
local function traverse(node_in, node_out, nodes)
    nodes = nodes or {}
    local down = reachable('children', node_in)
    local up = reachable('parents', node_out)

    -- pending[node] = number of node's parents inside the subgraph that have
    -- not been emitted yet. Only subgraph nodes get an entry.
    local pending = {}
    for node in pairs(down) do
        if up[node] then
            local count = 0
            for _, parent in ipairs(node.parents) do
                if down[parent] and up[parent] then count = count + 1 end
            end
            pending[node] = count
        end
    end

    local queue, head = {node_in}, 1
    local queued = {[node_in] = true}
    while head <= #queue do
        local node = queue[head]
        head = head + 1
        -- node_in itself may be a synthetic root outside the subgraph
        -- (see traverse_all); it only seeds the walk.
        local emitted = pending[node] ~= nil
        if emitted then insert(nodes, node) end
        for _, child in ipairs(node.children) do
            if pending[child] ~= nil then
                if emitted then pending[child] = pending[child] - 1 end
                if pending[child] == 0 and not queued[child] then
                    queued[child] = true
                    queue[#queue + 1] = child
                end
            end
        end
    end
    return nodes
end

-- Like traverse(), but for several inputs and outputs. It wraps them in a
-- synthetic root (children = nodes_in) and a synthetic sink
-- (parents = nodes_out). Those wrappers are not linked back, so they never
-- appear in the result.
local function traverse_all(nodes_in, nodes_out, nodes)
    local all_in = {children={}}
    local all_out = {parents={}}
    for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
    for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
    return traverse(all_in, all_out, nodes or {})
end

local Weights = Base:extend()
local Layer = Base:extend()
local Model = Base:extend()
local Input = Layer:extend()
local Relu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()

-- A block of parameters stored in the object's own array part.
-- `weight_init(t, fan_in, fan_out)` fills it. `shape` (a number or a table of
-- dimensions) is assigned by the owning layer before allocate() runs.
function Weights:init(weight_init)
    self.weight_init = weight_init
end

-- Sets self.size = prod(self.shape), zero-fills self[1..size], then runs
-- weight_init over it. Returns weight_init's result (the initializers in this
-- file return the table they were given).
function Weights:allocate(fan_in, fan_out)
    self.size = prod(self.shape)
    return allocate(self.size, self, function(t)
        return self.weight_init(t, fan_in, fan_out)
    end)
end

-- Per-name instance counter, used to give layers unique names like "Dense[2]".
local counter = {}

-- Base layer constructor. size_in/size_out stay nil until the layer is
-- shaped (by its own init or by make_shape when fed from a parent).
function Layer:init(name)
    assert(type(name) == "string")
    counter[name] = (counter[name] or 0) + 1
    self.name = name.."["..tostring(counter[name]).."]"
    self.parents = {}
    self.children = {}
    self.weights = {}
end

-- Default shaping: take size_in from the parent's output size, and make
-- size_out equal to size_in unless the layer already set it.
function Layer:make_shape(parent)
    if self.size_in == nil then self.size_in = parent.size_out end
    if self.size_out == nil then self.size_out = self.size_in end
end

-- Connects self -> child (shaping the child first) and returns child, so
-- calls can be chained: a:feed(b):feed(c).
function Layer:feed(child)
    assert(self.size_out ~= nil)
    child:make_shape(self)
    insert(self.children, child)
    insert(child.parents, self)
    return child
end

function Layer:forward()
    error("Unimplemented.")
end

-- Deterministic variant of forward; defaults to forward for layers whose
-- behaviour does not differ.
function Layer:forward_deterministic(...)
    return self:forward(...)
end

-- Creates a Weights object with the given initializer and registers it on
-- this layer.
function Layer:_new_weights(init)
    local w = Weights(init)
    insert(self.weights, w)
    return w
end

-- Total number of parameters over all of this layer's Weights (each
-- w.size is only set after init_weights()).
function Layer:get_size()
    local size = 0
    for _, w in ipairs(self.weights) do
        size = size + prod(w.size)
    end
    return size
end

-- (Re)allocates and initializes every Weights of this layer.
function Layer:init_weights()
    for _, w in ipairs(self.weights) do
        -- FIXME: HACK. allocate() only writes indices 1..size, so clear any
        -- previous contents first to avoid leaving stale trailing entries.
        for j in ipairs(w) do w[j] = nil end
        w:allocate(self.size_in, self.size_out)
    end
end

-- Computes this layer's output from the list of parent outputs.
-- Override this if you need multiple parents.
function Layer:_propagate(edges, deterministic)
    assert(#edges == 1, #edges)
    if deterministic then
        return self:forward_deterministic(edges[1])
    else
        return self:forward(edges[1])
    end
end

-- Gathers the already-computed outputs of this layer's parents from
-- `values` (node -> output), skipping parents without a value, and
-- propagates them.
function Layer:propagate(values, deterministic)
    local edges = {}
    for _, parent in ipairs(self.parents) do
        local X = values[parent]
        if X ~= nil then insert(edges, X) end
    end
    assert(#edges > 0, #edges)
    return self:_propagate(edges, deterministic)
end

-- Graph entry point of a fixed width; passes its input through unchanged.
function Input:init(size)
    Layer.init(self, "Input")
    assert(type(size) == 'number')
    self.size_in = size
    self.size_out = size
end

function Input:forward(X)
    assert(#X == self.size_in)
    return X
end

-- Element-wise rectifier: Y[i] = max(X[i], 0).
function Relu:init()
    Layer.init(self, "Relu")
end

-- The returned table is this layer's reusable cache and is overwritten by the
-- next call.
function Relu:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache
    for i = 1, #X do
        Y[i] = X[i] >= 0 and X[i] or 0
    end
    assert(#Y == self.size_out)
    return Y
end

-- Fully connected layer with `dim` outputs.
function Dense:init(dim)
    Layer.init(self, "Dense")
    assert(type(dim) == "number")
    self.dim = dim
    self.size_out = dim
    -- He-normal coefficients, zero biases.
    self.coeffs = self:_new_weights(init_he_normal)
    self.biases = self:_new_weights(init_zeros)
end

function Dense:make_shape(parent)
    self.size_in = parent.size_out
    self.coeffs.shape = {self.size_in, self.dim}
    self.biases.shape = self.dim
end

-- Y[i] = biases[i] + sum_j X[j] * coeffs[(i - 1) * size_in + j].
-- The returned table is this layer's reusable cache.
function Dense:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache
    local n = #X
    for i = 1, self.dim do
        local res = 0
        local c = (i - 1) * n
        for j = 1, n do
            res = res + X[j] * self.coeffs[c + j]
        end
        Y[i] = res + self.biases[i]
    end
    assert(#Y == self.size_out)
    return Y
end

-- Softmax over the whole input vector.
function Softmax:init()
    Layer.init(self, "Softmax")
end

-- The returned table is this layer's reusable cache.
function Softmax:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache
    -- Shift by the true maximum so the largest exponent is exp(0) = 1. This
    -- avoids overflow for large inputs, and 0/0 when every input is very
    -- negative. Starting from 0 instead of -huge broke the latter case.
    local alpha = -huge
    for i = 1, #X do alpha = max(alpha, X[i]) end
    local den = 0
    for i = 1, #X do
        local e = exp(X[i] - alpha)
        Y[i] = e
        den = den + e
    end
    for i = 1, #X do Y[i] = Y[i] / den end
    assert(#Y == self.size_out)
    return Y
end

-- A model over the subgraph between `nodes_in` and `nodes_out` (both
-- non-empty arrays of layers). self.nodes lists those layers once each,
-- parents before children.
function Model:init(nodes_in, nodes_out)
    assert(#nodes_in > 0, #nodes_in)
    assert(#nodes_out > 0, #nodes_out)
    self.nodes_in = nodes_in
    self.nodes_out = nodes_out
    -- find all the used (in-between) nodes in the graph.
    self.nodes = traverse_all(self.nodes_in, self.nodes_out)
end

-- (Re)initializes every layer's weights and recomputes self.n_param, the
-- total parameter count. Must be called before collect/distribute/save/load.
function Model:reset()
    self.n_param = 0
    for _, node in ipairs(self.nodes) do
        node:init_weights()
        self.n_param = self.n_param + node:get_size()
    end
end

-- Runs the graph on input X, which every input node receives.
-- `deterministic` (optional) is forwarded to each layer's _propagate.
-- Returns a table mapping each output node to its output. Those outputs are
-- typically the layers' reusable caches, overwritten by the next call.
function Model:forward(X, deterministic)
    local values = {}
    local outputs = {}
    for _, node in ipairs(self.nodes) do
        if contains(self.nodes_in, node) then
            values[node] = node:_propagate({X}, deterministic)
        else
            values[node] = node:propagate(values, deterministic)
        end
        if contains(self.nodes_out, node) then
            outputs[node] = values[node]
        end
    end
    return outputs
end

-- Returns a flat array of all the weights in the graph, in node order, then
-- per-layer Weights order.
-- If Lua had slices, we wouldn't need this. Future library idea?
function Model:collect()
    assert(self.n_param >= 0, self.n_param)
    local W = zeros(self.n_param)
    local i = 0
    for _, node in ipairs(self.nodes) do
        for _, w in ipairs(node.weights) do
            for j, v in ipairs(w) do W[i + j] = v end
            i = i + #w
        end
    end
    return W
end

-- Inverse of collect(): copies the flat array W back into the layers' weights.
function Model:distribute(W)
    local i = 0
    for _, node in ipairs(self.nodes) do
        for _, w in ipairs(node.weights) do
            for j in ipairs(w) do w[j] = W[i + j] end
            i = i + #w
        end
    end
end

-- Writes all weights to `fn` (default 'network.txt'), one number per line.
function Model:save(fn)
    fn = fn or 'network.txt'
    local f = open(fn, 'w')
    if f == nil then error("Failed to save network to file "..fn) end
    local W = self:collect()
    for _, v in ipairs(W) do
        -- %.17g round-trips a double exactly. f:write(v) would use Lua's
        -- default number conversion (%.14g), losing precision on reload.
        f:write(format('%.17g', v))
        f:write('\n')
    end
    f:close()
end

-- Reads weights written by save() from `fn` (default 'network.txt') and
-- distributes them. Raises if a line is not a number or if the value count
-- differs from self.n_param. The file is closed on every path.
function Model:load(fn)
    fn = fn or 'network.txt'
    local f = open(fn, 'r')
    if f == nil then error("Failed to load network from file "..fn) end
    local W = zeros(self.n_param)
    local i = 0
    for line in f:lines() do
        i = i + 1
        local n = tonumber(line)
        if n == nil then
            f:close()
            error("Failed reading line "..tostring(i).." of file "..fn)
        end
        W[i] = n
    end
    f:close()
    if i ~= self.n_param then
        error("Expected "..tostring(self.n_param).." weights in file "..fn
              ..", found "..tostring(i))
    end
    self:distribute(W)
end

return {
    uniform = uniform,
    normal = normal,
    copy = copy,
    zeros = zeros,
    init_zeros = init_zeros,
    init_he_uniform = init_he_uniform,
    init_he_normal = init_he_normal,

    Weights = Weights,
    Layer = Layer,
    Model = Model,
    Input = Input,
    Relu = Relu,
    Dense = Dense,
    Softmax = Softmax,
}