local assert = assert
local ceil = math.ceil
local cos = math.cos
local exp = math.exp
local floor = math.floor
local huge = math.huge
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local pairs = pairs
local pi = math.pi
local print = print
local remove = table.remove
local sin = math.sin
local sqrt = math.sqrt
local tostring = tostring
local uniform = math.random
local unpack = table.unpack or unpack

local Base = require("Base")
local util = require("util")

-- hacks

-- Debugging aid: prints a traceback of the caller using CRLF line endings.
local function helpme()
    -- the extra parentheses drop gsub's second result (the match count),
    -- which would otherwise be printed after the traceback.
    print((debug.traceback('helpme', 2):gsub("\n", "\r\n")))
end

-- math utilities

-- Product of all arguments, or of all elements when given a single table.
-- An empty table yields 1 (the empty product).
local function prod(x, ...)
    if type(x) == "table" then
        if #x == 0 then return 1 end
        return prod(unpack(x))
    end
    local ret = x
    for i = 1, select("#", ...) do
        ret = ret * select(i, ...)
    end
    return ret
end

-- Draws one sample from an (approximately) standard normal distribution.
local function normal()
    -- box muller
    -- uniform() + 1e-8 may slightly exceed 1, which makes the radicand a tiny
    -- negative number; clamp it so sqrt never returns NaN.
    local radicand = -2 * log(uniform() + 1e-8) + 1e-8
    return sqrt(max(radicand, 0)) * cos(2 * pi * uniform())
end

-- Fills `out` (or a new table) with n zeros and returns it.
-- n may be a shape table, in which case out.shape is set to it and
-- prod(shape) elements are written.
local function zeros(n, out)
    out = out or {}
    if type(n) == 'table' then
        local shape = n
        n = prod(shape)
        out.shape = shape
    end
    for i = 1, n do out[i] = 0 end
    return out
end

-- Like zeros, but fills with 0, 1, 2, ..., n - 1.
local function arange(n, out)
    out = out or {}
    if type(n) == 'table' then
        local shape = n
        n = prod(shape)
        out.shape = shape
    end
    for i = 1, n do out[i] = i - 1 end
    return out
end

-- Zero-fills `out` (or a new table) with size `t` (number or shape table),
-- then passes it through `init` when one is given. Returns the result.
local function allocate(t, out, init)
    out = out or {}
    local size = t
    if init ~= nil then
        return init(zeros(size, out))
    else
        return zeros(size, out)
    end
end

-- Weight initializers. Each overwrites t[1..#t] in place and returns t.
-- fan_in and fan_out are only read by the He initializers (fan_in).

local function init_zeros(t, fan_in, fan_out)
    for i = 1, #t do t[i] = 0 end
    return t
end

-- uniform in [-1, 1).
local function init_uniform(t, fan_in, fan_out)
    for i = 1, #t do t[i] = uniform() * 2 - 1 end
    return t
end

-- uniform in [-s, s) with s = sqrt(6 / fan_in).
local function init_he_uniform(t, fan_in, fan_out)
    local s = sqrt(6 / fan_in)
    for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
    return t
end

-- normal samples scaled by sqrt(2 / fan_in).
local function init_he_normal(t, fan_in, fan_out)
    local s = sqrt(2 / fan_in)
    for i = 1, #t do t[i] = normal() * s end
    return t
end

-- ndarray-ish stuff and more involved math

-- pretty-prints an nd-array.
-- t is a flat table with an optional .shape; fmt is the per-element format
-- and sep the separator. The remaining parameters are recursion state:
-- ti is the flat offset of the current sub-array, di the axis being printed,
-- depth the indentation level, and isfirst/islast flag the position of this
-- sub-array within its parent. Returns the formatted string.
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
    fmt = fmt or '%10.7f,'
    sep = sep or ','
    ti = ti or 0
    di = di or 1
    depth = depth or 0

    -- no shape: print as a single flat row.
    if t.shape == nil then
        local s = '['
        for i = 1, #t do s = s..fmt:format(t[i]) end
        return s..']'..sep..'\n'
    end

    local dim = t.shape[di]

    -- number of flat elements spanned by one step along axis di.
    local ti_step = 1
    for dj = di + 1, #t.shape do ti_step = ti_step * t.shape[dj] end

    local indent = (' '):rep(depth)

    local s = ''
    if di ~= #t.shape then
        if isfirst then s = s..indent..'[\n' else s = s..'[\n' end
        for i = 1, dim do
            s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
            ti = ti + ti_step
        end
        if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
    else
        -- innermost axis: print the elements themselves.
        s = s..indent..'['
        for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
        s = s..']'..sep..'\n'
    end
    return s
end

-- pretty-prints an nd-array of integers, each padded to n characters.
local function ppi(t, n, ...)
    -- TODO: determine maximum number of digits if n is omitted.
    n = n or 1
    return pp(t, '%'..tostring(n)..'i', ' ', ...)
end

-- Formats a shape for error messages. With isbatch the leading (batch) axis
-- of `shape` is skipped; without it a symbolic "n" is printed in its place.
local function checkshape_helper(shape, isbatch)
    local s = '{ '
    if not isbatch then s = s..'n, ' end
    for i, v in ipairs(shape) do
        if not isbatch or i > 1 then
            s = s..tostring(v)..(i ~= #shape and ', ' or ' ')
        end
    end
    return s..'}'
end

-- Verifies that batch.shape equals {n, unpack(shape)} for some n and returns
-- that n (the batch size). Raises an error, attributed to the caller,
-- when the shapes disagree.
local function checkshape(batch, shape)
    assert(type(batch) == 'table', "batch is not an array")
    assert(batch.shape ~= nil, "batch is missing a shape")
    if #batch.shape == 1 then
        error("batch shape is incomplete", 2)
    end
    -- a batch with extra trailing axes must not slip through.
    local matches = #batch.shape == #shape + 1
    for n = 1, #shape do
        if not matches then break end
        if batch.shape[n + 1] ~= shape[n] then matches = false end
    end
    if not matches then
        local s1 = checkshape_helper(batch.shape, true)
        local s2 = checkshape_helper(shape, false)
        error("shapes do not match: "..s1.." ~= "..s2, 2)
    end
    return batch.shape[1]
end

-- Assigns a new shape (given as varargs) to `a` in place and returns `a`.
-- The element count must be unchanged.
local function reshape(a, ...)
    local new_shape = {...}
    assert(#a == prod(new_shape), "new shape does not fit size")
    a.shape = new_shape
    return a
end

-- Allocates a zeroed array of shape {bs, unpack(shape)}.
-- Returns nil when bs is nil (no batch size known yet).
local function cache(bs, shape)
    if bs == nil then return nil end
    local fullshape = util.copy(shape)
    -- table.insert takes (t, pos, value): prepend the batch size.
    insert(fullshape, 1, bs)
    return zeros(fullshape)
end

-- Contracts axis ax_a of `a` with axis ax_b of `b`.
-- ax_a defaults to the last axis of a; ax_b defaults to the second-to-last
-- axis of b (or its only axis when b is one-dimensional).
-- The result's shape is a's remaining axes followed by b's remaining axes.
-- `out`, when given, must already hold that many elements; it is overwritten
-- and returned. Arrays are flat and addressed in row-major order.
local function dot(a, b, ax_a, ax_b, out)
    ax_a = ax_a or #a.shape
    ax_b = ax_b or max(#b.shape - 1, 1)

    assert(a.shape[ax_a] == b.shape[ax_b], "dotted axes do not match")
    local dim = a.shape[ax_a]

    local out_shape = {}
    for di = 1, #a.shape do
        if di ~= ax_a then insert(out_shape, a.shape[di]) end
    end
    for di = 1, #b.shape do
        if di ~= ax_b then insert(out_shape, b.shape[di]) end
    end

    if out == nil then
        out = zeros(prod(out_shape))
    else
        assert(prod(out_shape) == #out, "given output is the wrong size")
    end
    out.shape = out_shape

    -- a0/b0: number of elements spanned by the axes before the contracted one.
    -- a1/b1: number of elements spanned by the axes after it (its stride).
    local a0 = 1
    local a1 = 1
    local b0 = 1
    local b1 = 1
    for di = 1, ax_a - 1 do a0 = a0 * a.shape[di] end
    for di = 1, ax_b - 1 do b0 = b0 * b.shape[di] end
    for di = ax_a + 1, #a.shape do a1 = a1 * a.shape[di] end
    for di = ax_b + 1, #b.shape do b1 = b1 * b.shape[di] end

    -- one step of a leading index skips a whole (dim x inner) slab.
    local a_slab = dim * a1
    local b_slab = dim * b1

    local o = 1
    for i0 = 0, a0 - 1 do
        local i = i0 * a_slab
        for j = 1, a1 do
            for k0 = 0, b0 - 1 do
                local k = k0 * b_slab
                for m = 1, b1 do
                    local res = 0
                    local x = i + j
                    local y = k + m
                    for d = 1, dim do
                        res = res + a[x] * b[y]
                        x = x + a1
                        y = y + b1
                    end
                    out[o] = res
                    o = o + 1
                end
            end
        end
    end

    return out
end

-- nodal

-- Returns the nodes lying between node_in and node_out, ordered so that
-- every node comes after all of its parents.
-- The `nodes` argument is accepted for compatibility but not used; a fresh
-- list is always returned. With dummy_mode, node_in is treated as a
-- placeholder: it is considered reachable from node_out and is removed from
-- the result.
local function traverse(node_in, node_out, nodes, dummy_mode)
    -- i have no idea if this is any algorithm in particular.

    -- pass 1: mark everything node_out depends on.
    local seen_up = {}
    local q = {node_out}
    while #q > 0 do
        local node = remove(q, 1)
        -- expand each ancestor only once; shared ancestors would otherwise
        -- be walked again for every path leading to them.
        if not seen_up[node] then
            seen_up[node] = true
            for _, parent in ipairs(node.parents) do insert(q, parent) end
        end
    end

    if dummy_mode then seen_up[node_in] = true end

    -- pass 2: walk down from node_in, emitting a marked node once all of
    -- its parents have been emitted.
    nodes = {}
    q = {node_in}
    while #q > 0 do
        local node = remove(q, 1)
        if seen_up[node] then
            local all_parents_added = true
            for _, parent in ipairs(node.parents) do
                if not util.contains(nodes, parent) then
                    all_parents_added = false
                    break
                end
            end
            if not util.contains(nodes, node) and all_parents_added then
                insert(nodes, node)
            end
            for _, child in ipairs(node.children) do insert(q, child) end
        end
    end

    if dummy_mode then remove(nodes, util.indexof(nodes, node_in)) end

    return nodes
end

-- traverse() for several inputs and outputs: wires them to two placeholder
-- nodes and traverses between those in dummy mode.
local function traverse_all(nodes_in, nodes_out, nodes)
    local all_in = {children={}, parents={}}
    local all_out = {children={}, parents={}}
    for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
    for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
    return traverse(all_in, all_out, nodes or {}, true)
end

-- classes

local Weights = Base:extend()
local Layer = Base:extend()
local Model = Base:extend()
local Input = Layer:extend()
local Merge = Layer:extend()
local Relu = Layer:extend()
local Gelu = Layer:extend()
local Cos = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
local Embed = Layer:extend()
local LayerNorm = Layer:extend()

-- Weights: a flat array of parameters; .shape is assigned by the owning layer.

function Weights:init(weight_init)
    self.weight_init = weight_init
end

-- Fills this object's array part with prod(self.shape) values produced by
-- the initializer given at construction. Returns self.
function Weights:allocate(fan_in, fan_out)
    self.size = prod(self.shape)
    return allocate(self.size, self, function(t)
        --print('initializing weights of size', self.size, 'with fans', fan_in, fan_out)
        return self.weight_init(t, fan_in, fan_out)
    end)
end

-- number of layers created so far, per layer type; used to name them.
local counter = {}

-- Layer: base class for graph nodes. Names are "<type>[<n>]".
function Layer:init(name)
    assert(type(name) == "string")
    counter[name] = (counter[name] or 0) + 1
    self.name = name.."["..tostring(counter[name]).."]"
    self.parents = {}
    self.children = {}
    self.weights = {}
    --self.shape_in = nil
    --self.shape_out = nil
end

-- Called when `parent` is connected to this layer. By default the input
-- shape is the parent's output shape and the output shape equals the input
-- shape, unless either has already been set.
function Layer:make_shape(parent)
    if self.shape_in == nil then self.shape_in = parent.shape_out end
    if self.shape_out == nil then self.shape_out = self.shape_in end
end

-- Connects this layer's output to `child` and returns `child`,
-- so that calls can be chained.
function Layer:feed(child)
    assert(self.shape_out ~= nil, "missing output shape: "..self.name)
    child:make_shape(self)
    insert(self.children, child)
    insert(child.parents, self)
    return child
end

function Layer:forward()
    error("Unimplemented.")
end

-- Defaults to forward(); a hook for layers that need a separate
-- deterministic path.
function Layer:forward_deterministic(...)
    return self:forward(...)
end

-- Creates a Weights object with the given initializer and registers it.
function Layer:_new_weights(init)
    local w = Weights(init)
    insert(self.weights, w)
    return w
end

-- Total number of parameters across this layer's weights.
function Layer:get_size()
    local size = 0
    for i, w in ipairs(self.weights) do size = size + prod(w.shape) end
    return size
end

-- (Re)initializes every weight of this layer and drops cached buffers.
function Layer:init_weights()
    for i, w in ipairs(self.weights) do
        --print("allocating weights", i, "of", self.name)
        for j, v in ipairs(w) do w[j] = nil end -- FIXME: HACK
        w:allocate(prod(self.shape_in), prod(self.shape_out))
    end
    self:reset_cache()
end

-- Records the batch size; subclasses also (re)allocate their buffers here.
function Layer:reset_cache(bs)
    self.bs = bs
end

-- Runs the layer on the outputs of its parents (`edges`).
function Layer:_propagate(edges, deterministic)
    -- override this if you need multiple parents.
    assert(#edges == 1, ("%s edges for node %s (expected 1)"):format(#edges, self.name))
    if deterministic then
        return self:forward_deterministic(edges[1])
    else
        return self:forward(edges[1])
    end
end

-- Gathers the already-computed outputs of this layer's parents from
-- `values` (keyed by node) and runs the layer on them.
function Layer:propagate(values, deterministic)
    local edges = {}
    for i, parent in ipairs(self.parents) do
        if values[parent] ~= nil then
            local X = values[parent]
            insert(edges, X)
        end
    end
    assert(#edges > 0, ("%s edges for node %s (expected >0)"):format(#edges, self.name))
    local Y = self:_propagate(edges, deterministic)
    return Y
end

-- Input: entry point of the graph; passes its batch through after
-- validating the shape.

function Input:init(shape)
    Layer.init(self, "Input")
    assert(type(shape) == 'table')
    self.shape_in = shape
    self.shape_out = shape
end

function Input:forward(X)
    checkshape(X, self.shape_in)
    return X
end

-- Merge: concatenates the flattened outputs of all parents, per sample.
-- Note that shape_in holds the number of parents here, not a shape.

function Merge:init()
    Layer.init(self, "Merge")
    self.size = 0
    self.shape_in = 0
end

function Merge:make_shape(parent)
    self.size = self.size + prod(parent.shape_out)
    self.shape_in = self.shape_in + 1 -- TODO: more robust.
    self.shape_out = {self.size}
end

function Merge:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
end

function Merge:_propagate(edges, deterministic)
    assert(#edges == self.shape_in)

    local bs = edges[1].shape[1]
    if bs ~= self.bs then self:reset_cache(bs) end
    local Y = self.cache
    local size = self.size

    -- the output is {bs, size} in row-major order, so each parent's
    -- features must land in its own column range of every row.
    local offset = 0
    for _, X in ipairs(edges) do
        assert(X.shape[1] == bs, "batch sizes do not match")
        local width = 1
        for d = 2, #X.shape do width = width * X.shape[d] end
        for b = 0, bs - 1 do
            local yi = b * size + offset
            local xi = b * width
            for j = 1, width do Y[yi + j] = X[xi + j] end
        end
        offset = offset + width
    end

    checkshape(Y, self.shape_out)
    return Y
end

-- Relu: elementwise max(x, 0).

function Relu:init()
    Layer.init(self, "Relu")
end

function Relu:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
    self.dcache = cache(bs, self.shape_in)
end

function Relu:forward(X)
    local bs = checkshape(X, self.shape_in)
    if bs ~= self.bs then self:reset_cache(bs) end
    local Y = self.cache

    for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end

    checkshape(Y, self.shape_out)
    return Y
end

-- Gelu: elementwise x * sigmoid(1.704 * x).

function Gelu:init()
    Layer.init(self, "Gelu")
end

function Gelu:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
    self.cache_a = cache(bs, self.shape_out)
    self.cache_sig = cache(bs, self.shape_out)
    self.dcache = cache(bs, self.shape_in)
end

function Gelu:forward(X)
    local bs = checkshape(X, self.shape_in)
    if bs ~= self.bs then self:reset_cache(bs) end
    local Y = self.cache
    local a = self.cache_a
    local sig = self.cache_sig

    -- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
    for i = 1, #X do
        a[i] = 1.704 * X[i]
        sig[i] = 1 / (1 + exp(-a[i]))
        Y[i] = X[i] * sig[i]
    end

    checkshape(Y, self.shape_out)
    return Y
end

-- Cos: elementwise cosine.

function Cos:init()
    Layer.init(self, "Cos")
end

function Cos:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
    self.dcache = cache(bs, self.shape_in)
end

function Cos:forward(X)
    local bs = checkshape(X, self.shape_in)
    if bs ~= self.bs then self:reset_cache(bs) end
    local Y = self.cache

    for i = 1, #X do Y[i] = cos(X[i]) end

    checkshape(Y, self.shape_out)
    return Y
end

-- Dense: Y = X * coeffs + biases, producing `dim` outputs per sample.

function Dense:init(dim)
    Layer.init(self, "Dense")
    assert(type(dim) == "number")
    self.dim = dim
    self.shape_out = {dim}
    self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
    self.biases = self:_new_weights(init_zeros)
end

function Dense:make_shape(parent)
    self.shape_in = parent.shape_out
    self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
    self.biases.shape = {1, self.dim}
end

function Dense:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
    self.cache_x = cache(bs, self.shape_in)
    self.dcache = cache(bs, self.shape_in)
end

function Dense:forward(X)
    local bs = checkshape(X, self.shape_in)
    if self.bs ~= bs then self:reset_cache(bs) end
    local Y = self.cache

    --dot_1aab(X, self.coeffs, Y)
    dot(X, self.coeffs, 2, 1, Y)

    -- the bias applies to every row of the batch, not just the first.
    local dim = self.dim
    local biases = self.biases
    for row = 0, bs - 1 do
        local off = row * dim
        for i = 1, dim do Y[off + i] = Y[off + i] + biases[i] end
    end

    checkshape(Y, self.shape_out)
    return Y
end

-- Softmax: normalizes over the last axis, independently for each row.

function Softmax:init()
    Layer.init(self, "Softmax")
end

function Softmax:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
end

function Softmax:forward(X)
    local bs = checkshape(X, self.shape_in)
    if self.bs ~= bs then self:reset_cache(bs) end
    local Y = self.cache

    local len = X.shape[#X.shape]
    local rows = 1
    for d = 1, #X.shape - 1 do rows = rows * X.shape[d] end

    for r = 0, rows - 1 do
        local j = r * len

        -- the maximum and the denominator belong to this row alone.
        local alpha = -huge
        for i = j + 1, j + len do alpha = max(alpha, X[i]) end

        -- subtracting the row maximum keeps exp() from overflowing.
        local den = 0
        for i = j + 1, j + len do
            local e = exp(X[i] - alpha)
            Y[i] = e
            den = den + e
        end

        for i = j + 1, j + len do Y[i] = Y[i] / den end
    end

    checkshape(Y, self.shape_out)
    return Y
end

-- Embed: replaces each zero-based token id with its `dim`-sized row of the
-- lookup table.

function Embed:init(vocab, dim)
    Layer.init(self, "Embed")
    assert(type(vocab) == "number")
    assert(type(dim) == "number")
    self.vocab = vocab
    self.dim = dim
    self.lut = self:_new_weights(init_uniform)
    self.lut.shape = {self.vocab, self.dim}
end

function Embed:make_shape(parent)
    self.shape_in = parent.shape_out
    self.shape_out = {parent.shape_out[1] * self.dim}
end

function Embed:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
    self.cache_x = cache(bs, self.shape_in)
end

function Embed:forward(X)
    local bs = checkshape(X, self.shape_in)
    if self.bs ~= bs then self:reset_cache(bs) end
    local Y = self.cache

    local dim = self.dim
    local lut = self.lut
    local yi = 0
    for i, x in ipairs(X) do
        local xi = x * dim
        for j = 1, dim do
            local v = lut[xi + j]
            if v == nil then
                -- a nil here would silently punch a hole into the output.
                error(("no embedding for token %s in %s"):format(tostring(x), self.name))
            end
            Y[yi + j] = v
        end
        yi = yi + dim
    end

    checkshape(Y, self.shape_out)
    return Y
end

-- LayerNorm: normalizes each sample to zero mean and unit variance over
-- all of its features. eps (default 1e-5) is added to the variance.

function LayerNorm:init(eps)
    Layer.init(self, "LayerNorm")
    if eps == nil then eps = 1e-5 end
    assert(type(eps) == "number")
    self.eps = eps
end

function LayerNorm:reset_cache(bs)
    self.bs = bs
    self.cache = cache(bs, self.shape_out)
end

function LayerNorm:forward(X)
    local bs = checkshape(X, self.shape_in)
    if self.bs ~= bs then self:reset_cache(bs) end
    local Y = self.cache

    -- statistics are computed per sample, never across the batch.
    local width = prod(self.shape_in)
    for b = 0, bs - 1 do
        local first = b * width + 1
        local last = first + width - 1

        local mean = 0
        for i = first, last do mean = mean + X[i] / width end

        local var = 0
        for i = first, last do
            local delta = X[i] - mean
            Y[i] = delta
            var = var + delta * delta / width
        end

        local std = sqrt(var + self.eps)
        for i = first, last do Y[i] = Y[i] / std end
    end

    checkshape(Y, self.shape_out)
    return Y
end

-- Model: a graph of layers between a list of input nodes and a list of
-- output nodes.

function Model:init(nodes_in, nodes_out)
    assert(#nodes_in > 0, #nodes_in)
    assert(#nodes_out > 0, #nodes_out)
    --if #nodes_in == 0 and type(nodes_in) == "table" then nodes_in = {nodes_in} end
    --if #nodes_out == 0 and type(nodes_out) == "table" then nodes_out = {nodes_out} end

    self.nodes_in = nodes_in
    self.nodes_out = nodes_out

    -- find all the used (inbetween) nodes in the graph.
    self.nodes = traverse_all(self.nodes_in, self.nodes_out)
end

-- (Re)initializes the weights of every node, printing each node's name and
-- size, and records the total parameter count in self.n_param.
function Model:reset()
    self.n_param = 0
    for _, node in ipairs(self.nodes) do
        print(node.name, node:get_size())
        node:init_weights()
        self.n_param = self.n_param + node:get_size()
    end
end

-- Runs the graph. `inputs` maps each input node to its batch; the result
-- maps each output node to its output. When `deterministic` is truthy,
-- layers run forward_deterministic instead of forward.
function Model:forward(inputs, deterministic)
    local values = {}
    local outputs = {}
    for i, node in ipairs(self.nodes) do
        --print(i, node.name)
        if util.contains(self.nodes_in, node) then
            local X = inputs[node]
            assert(X ~= nil, ("missing input for node %s"):format(node.name))
            assert(X.shape, ("missing shape for node %s"):format(node.name))
            values[node] = node:_propagate({X}, deterministic)
        else
            values[node] = node:propagate(values, deterministic)
        end
        if util.contains(self.nodes_out, node) then
            outputs[node] = values[node]
        end
    end
    return outputs
end

function Model:cleargrad()
    error("TODO") -- TODO
end

-- Prints the graph's edges in Graphviz "dot" syntax.
function Model:print()
    print("digraph G {")
    for _, parent in ipairs(self.nodes) do
        for _, child in ipairs(parent.children) do
            print('\t'..parent.name..'->'..child.name..';')
        end
    end
    print('}')
end

function Model:collect()
    -- return a flat array of all the weights in the graph.
    -- if Lua had slices, we wouldn't need this. future library idea?
    assert(self.n_param >= 0, self.n_param)
    local W = zeros(self.n_param)
    local i = 0
    for _, node in ipairs(self.nodes) do
        for _, w in ipairs(node.weights) do
            for j, v in ipairs(w) do W[i+j] = v end
            i = i + #w
        end
    end
    return W
end

function Model:distribute(W)
    -- inverse operation of collect().
    local i = 0
    for _, node in ipairs(self.nodes) do
        for _, w in ipairs(node.weights) do
            for j, v in ipairs(w) do w[j] = W[i+j] end
            i = i + #w
        end
    end
end

function Model:default_filename()
    return ('network%07i.txt'):format(self.n_param)
end

-- Writes every weight to `fn` (default: default_filename()), one per line.
function Model:save(fn)
    fn = fn or self:default_filename()
    -- gather first so that a failure here cannot leave a truncated file.
    local W = self:collect()
    local f = open(fn, 'w')
    if f == nil then error("Failed to save network to file "..fn) end
    for i, v in ipairs(W) do
        -- 17 significant digits are enough to restore a double exactly.
        f:write(("%.17g"):format(v), '\n')
    end
    f:close()
end

-- Reads weights written by save() from `fn` (default: default_filename())
-- and distributes them over the graph. Raises an error if the file cannot
-- be opened, a line is not a number, or the number of values differs from
-- self.n_param.
function Model:load(fn)
    fn = fn or self:default_filename()
    local W = zeros(self.n_param)
    local f = open(fn, 'r')
    if f == nil then error("Failed to load network from file "..fn) end
    local i = 0
    for line in f:lines() do
        i = i + 1
        local n = tonumber(line)
        if n == nil then
            f:close() -- don't leak the handle on the error path.
            error("Failed reading line "..tostring(i).." of file "..fn)
        end
        W[i] = n
    end
    f:close()
    if i ~= self.n_param then
        error(("Expected %s values but found %s in file %s"):format(
            tostring(self.n_param), tostring(i), fn))
    end
    self:distribute(W)
end

return {
    prod = prod,
    uniform = uniform,
    normal = normal,
    zeros = zeros,
    arange = arange,
    allocate = allocate,
    init_zeros = init_zeros,
    init_uniform = init_uniform,
    init_he_uniform = init_he_uniform,
    init_he_normal = init_he_normal,
    reshape = reshape,
    pp = pp,
    ppi = ppi,
    dot = dot,

    traverse = traverse,
    traverse_all = traverse_all,

    Weights = Weights,
    Layer = Layer,
    Model = Model,
    Input = Input,
    Merge = Merge,
    Relu = Relu,
    Gelu = Gelu,
    Cos = Cos,
    Dense = Dense,
    Softmax = Softmax,
    Embed = Embed,
    LayerNorm = LayerNorm,
}