smbot/nn.lua

775 lines
19 KiB
Lua

local ceil = math.ceil
local cos = math.cos
local exp = math.exp
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local pairs = pairs
local pi = math.pi
local print = print
local remove = table.remove
local sin = math.sin
local sqrt = math.sqrt
local tostring = tostring
local uniform = math.random
local unpack = table.unpack or unpack
local Base = require("Base")
-- hacks
local function helpme()
  -- Debug aid: print a stack trace starting at the caller, with "\n"
  -- converted to "\r\n". string.gsub returns (string, count); the extra
  -- parentheses truncate that to the string so the count is not printed too.
  print((debug.traceback('helpme', 2):gsub("\n", "\r\n")))
end
-- general utilities
local function copy(t, out) -- shallow copy
  -- Copy every key/value pair of `t` into `out` (a fresh table when omitted)
  -- and return `out`. Nested tables are shared, not duplicated.
  out = out or {}
  for k, v in pairs(t) do out[k] = v end
  return out
end
local function indexof(t, a)
  -- Return the first key (in pairs() order) whose value equals `a`,
  -- or nil when no value matches.
  assert(type(t) == "table")
  for key, value in pairs(t) do
    if value == a then
      return key
    end
  end
  return nil
end
local function contains(t, a)
  -- True when some value of `t` equals `a` (compared with ==).
  if indexof(t, a) == nil then
    return false
  end
  return true
end
-- math utilities
local function prod(x, ...)
  -- Product of the arguments. A table first argument is treated as a shape
  -- and its entries are multiplied instead. With no arguments returns nil.
  if type(x) == "table" then
    return prod(unpack(x))
  end
  local rest = {...}
  local count = select("#", ...)
  local result = x
  for k = 1, count do
    result = result * rest[k]
  end
  return result
end
local function normal() -- box muller
  -- Draw one standard-normal sample with the Box-Muller transform.
  -- uniform() lies in [0, 1), so 1 - uniform() lies in (0, 1]: log() never
  -- sees 0 and is never positive, which keeps sqrt()'s argument >= 0.
  -- (The previous epsilon offsets could push the argument slightly below 0,
  -- giving NaN, when uniform() returned a value within 1e-8 of 1.)
  local u1 = 1 - uniform()
  local u2 = uniform()
  return sqrt(-2 * log(u1)) * cos(2 * pi * u2)
end
local function zeros(n, out)
  -- Set slots 1..count of `out` (or of a new table) to 0 and return it.
  -- `n` is either an element count or a shape table; a shape is also
  -- recorded as out.shape. Slots beyond count are left untouched.
  local dst = out or {}
  local count = n
  if type(n) == 'table' then
    count = prod(n)
    dst.shape = n
  end
  for idx = 1, count do
    dst[idx] = 0
  end
  return dst
end
local function arange(n, out)
  -- Fill `out` (or a new table) with 0, 1, ..., count-1 and return it.
  -- `n` is either an element count or a shape table; a shape is also
  -- recorded as out.shape.
  local dst = out or {}
  local count = n
  if type(n) == 'table' then
    count = prod(n)
    dst.shape = n
  end
  for idx = 1, count do
    dst[idx] = idx - 1
  end
  return dst
end
local function allocate(t, out, init)
  -- Zero-fill `out` (or a new table) to the size/shape `t`. When an
  -- initializer is given, return init(buffer) instead of the zeroed buffer.
  local buffer = zeros(t, out or {})
  if init == nil then
    return buffer
  end
  return init(buffer)
end
local function init_zeros(t, fan_in, fan_out)
  -- Initializer: overwrite every element of `t` with 0 (fans unused).
  for idx = #t, 1, -1 do
    t[idx] = 0
  end
  return t
end
local function init_uniform(t, fan_in, fan_out)
  -- Initializer: fill `t` with samples drawn uniformly from [-1, 1)
  -- (fans unused).
  local n = #t
  for idx = 1, n do
    local u = uniform()
    t[idx] = u * 2 - 1
  end
  return t
end
local function init_he_uniform(t, fan_in, fan_out)
  -- He-uniform initializer: samples from [-s, s) with s = sqrt(6 / fan_in).
  local scale = sqrt(6 / fan_in)
  local n = #t
  for idx = 1, n do
    local u = uniform()
    t[idx] = (u * 2 - 1) * scale
  end
  return t
end
local function init_he_normal(t, fan_in, fan_out)
  -- He-normal initializer: normal samples scaled by sqrt(2 / fan_in).
  local scale = sqrt(2 / fan_in)
  local n = #t
  for idx = 1, n do
    local z = normal()
    t[idx] = z * scale
  end
  return t
end
-- ndarray-ish stuff and more involved math
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
-- Pretty-formats an nd-array and RETURNS the resulting string (it does not
-- print). `t` is a flat array; when `t.shape` is set, the array is laid out
-- as nested bracketed rows, one nesting level per dimension.
-- Parameters (all but `t` are optional):
--   fmt      format applied to each element (default '%10.7f,').
--   sep      appended after each element of an innermost row and after each
--            closing bracket (default ',').
--   ti       0-based flat offset of the sub-array being formatted (recursion).
--   di       index of the dimension being formatted (recursion, starts at 1).
--   depth    indentation level, one space per level (recursion).
--   isfirst  when true, the opening '[' of an outer level is indented.
--   islast   when true, the closing ']' of an outer level ends the line.
-- NOTE(review): the default fmt already ends in ',' and innermost rows also
-- append `sep`, so the defaults produce ',,' after each element — confirm.
fmt = fmt or '%10.7f,'
sep = sep or ','
ti = ti or 0
di = di or 1
depth = depth or 0
-- Unshaped array: one bracketed row; elements use fmt only (no sep between).
if t.shape == nil then
local s = '['
for i = 1, #t do s = s..fmt:format(t[i]) end
return s..']'..sep..'\n'
end
local dim = t.shape[di]
-- Number of flat elements spanned by one step along dimension di.
local ti_step = 1
for dj = di + 1, #t.shape do ti_step = ti_step * t.shape[dj] end
local indent = ''
for i = 1, depth do indent = indent..' ' end
local s = ''
if di ~= #t.shape then
-- Outer dimension: recurse into each of its `dim` sub-arrays.
if isfirst then s = s..indent..'[\n' else s = s..'[\n' end
for i = 1, dim do
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
ti = ti + ti_step
end
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
else
-- Innermost dimension: format elements ti+1 .. ti+dim on one line.
s = s..indent..'['
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
s = s..']'..sep..'\n'
end
return s
end
local function ppi(t, n, ...)
  -- Pretty-format an integer nd-array: elements use the '%<n>i' format
  -- (n defaults to 1) and ' ' is passed to pp() as the separator. Any
  -- further arguments are forwarded to pp().
  -- TODO: determine maximum number of digits if n is omitted.
  local width = n or 1
  local fmt = '%' .. tostring(width) .. 'i'
  return pp(t, fmt, ' ', ...)
end
local function checkshape_helper(shape, isbatch)
  -- Render a shape for error messages, e.g. '{ n, 3, 4 }'.
  -- For a batch shape (isbatch) the leading batch dimension is dropped;
  -- otherwise a symbolic 'n, ' stands in for it.
  local parts = {'{ '}
  if not isbatch then
    parts[#parts + 1] = 'n, '
  end
  local last = #shape
  for i, v in ipairs(shape) do
    if i > 1 or not isbatch then
      parts[#parts + 1] = tostring(v)
      parts[#parts + 1] = i == last and ' ' or ', '
    end
  end
  parts[#parts + 1] = '}'
  return table.concat(parts)
end
local function checkshape(batch, shape)
  -- Validate that `batch` is a batched array whose per-sample dimensions
  -- (batch.shape[2], batch.shape[3], ...) begin with `shape`. Extra trailing
  -- dimensions are not checked. Mismatches are raised at the caller's level.
  -- Returns the batch size, batch.shape[1].
  assert(type(batch) == 'table', "batch is not an array")
  local bshape = batch.shape
  assert(bshape ~= nil, "batch is missing a shape")
  if #bshape == 1 then
    error("batch shape is incomplete", 2)
  end
  for d = 1, #shape do
    if bshape[d + 1] ~= shape[d] then
      local got = checkshape_helper(bshape, true)
      local want = checkshape_helper(shape, false)
      error("shapes do not match: " .. got .. " ~= " .. want, 2)
    end
  end
  return bshape[1]
end
local function reshape(a, ...)
  -- Attach a new shape (given as varargs) to flat array `a` in place.
  -- The element count must match; returns `a`.
  local dims = {...}
  local size = prod(dims)
  assert(#a == size, "new shape does not fit size")
  a.shape = dims
  return a
end
local function cache(bs, shape)
  -- Allocate a zero-filled buffer for `bs` samples of `shape`, i.e. with
  -- full shape {bs, shape...}. Returns nil when bs is nil (no batch yet).
  if bs == nil then return nil end
  local fullshape = copy(shape)
  -- table.insert(list, pos, value): prepend the batch dimension. (The
  -- previous argument order inserted the value 1 at position bs, which only
  -- happened to work for bs == 1.)
  insert(fullshape, 1, bs)
  return zeros(fullshape)
end
local function dot(a, b, ax_a, ax_b, out)
  -- Contract `a` and `b` over one axis each (like a single-axis-pair
  -- tensordot). Both are flat row-major arrays carrying a `.shape` field.
  --   ax_a  axis of `a` to contract (default: its last axis).
  --   ax_b  axis of `b` to contract (default: its second-to-last axis, or
  --         its only axis when `b` is 1-D).
  --   out   optional preallocated output of exactly the result size.
  -- Returns `out`, shaped {a's other axes..., b's other axes...}.
  ax_a = ax_a or #a.shape
  if ax_b == nil then
    ax_b = #b.shape > 1 and #b.shape - 1 or 1
  end
  local dim = a.shape[ax_a]
  -- dim ~= nil rejects out-of-range axes (nil == nil would otherwise pass).
  assert(dim ~= nil and dim == b.shape[ax_b], "dotted axes do not match")
  local out_shape = {}
  for di = 1, #a.shape do if di ~= ax_a then insert(out_shape, a.shape[di]) end end
  for di = 1, #b.shape do if di ~= ax_b then insert(out_shape, b.shape[di]) end end
  -- An empty out_shape (1-D . 1-D) yields a single scalar element.
  local out_size = 1
  for _, d in ipairs(out_shape) do out_size = out_size * d end
  if out == nil then
    out = zeros(out_size)
  else
    assert(out_size == #out, "given output is the wrong size")
  end
  out.shape = out_shape
  -- View `a` as (a0, dim, a1) and `b` as (b0, dim, b1): a0/b0 are products of
  -- the axes before the contracted one, a1/b1 of the axes after it.
  local a0, a1, b0, b1 = 1, 1, 1, 1
  for di = 1, ax_a - 1 do a0 = a0 * a.shape[di] end
  for di = ax_a + 1, #a.shape do a1 = a1 * a.shape[di] end
  for di = 1, ax_b - 1 do b0 = b0 * b.shape[di] end
  for di = ax_b + 1, #b.shape do b1 = b1 * b.shape[di] end
  -- One step of an outer index skips a whole (dim, a1) / (dim, b1) slab.
  -- Stepping by `dim` alone is only correct when a1 == 1 / b1 == 1.
  local a_slab = dim * a1
  local b_slab = dim * b1
  if a_slab == 0 or b_slab == 0 then
    -- Empty contraction (or empty result): every output element is 0.
    -- Also avoids a zero-step numeric for loop below.
    for n = 1, out_size do out[n] = 0 end
    return out
  end
  local o = 1
  for i = 0, a0 * a_slab - 1, a_slab do
    for j = 1, a1 do
      for k = 0, b0 * b_slab - 1, b_slab do
        for m = 1, b1 do
          local res = 0
          local x = i + j
          local y = k + m
          for _ = 1, dim do
            res = res + a[x] * b[y]
            x = x + a1
            y = y + b1
          end
          out[o] = res
          o = o + 1
        end
      end
    end
  end
  return out
end
-- nodal
local function traverse(node_in, node_out, nodes, dummy_mode)
-- Collect the nodes lying between node_in and node_out, ordered so that every
-- node appears after all of its parents.
--   Phase 1: walk upwards from node_out through `parents`, marking every
--            ancestor of node_out (node_out included) in seen_up.
--   Phase 2: walk downwards from node_in through `children`, expanding only
--            marked nodes; a node is appended once all of its parents are
--            already in the result.
-- A node with a parent that never gets added (e.g. one not reachable from
-- node_in) is therefore left out, together with its descendants.
-- dummy_mode: node_in is a placeholder that is not reachable from node_out via
-- `parents` (as with the virtual source built by traverse_all). It is
-- force-marked after phase 1 and removed from the result at the end.
-- NOTE(review): the `nodes` argument is effectively ignored; it is replaced by
-- a fresh table before phase 2.
-- NOTE(review): neither walk tracks visited nodes, so a node can be queued
-- several times (e.g. in diamond-shaped graphs), and `contains` is a linear
-- scan. This may get slow on large graphs.
nodes = nodes or {}
local seen_up = {}
local q = {node_out}
-- Phase 1: breadth-first walk over parents.
while #q > 0 do
local node = remove(q, 1)
seen_up[node] = true
for _, parent in ipairs(node.parents) do insert(q, parent) end
end
if dummy_mode then seen_up[node_in] = true end
nodes = {}
q = {node_in}
-- Phase 2: breadth-first walk over children of marked nodes.
while #q > 0 do
local node = remove(q, 1)
if seen_up[node] then
local all_parents_added = true
for _, parent in ipairs(node.parents) do
if not contains(nodes, parent) then
all_parents_added = false
break
end
end
-- Append only once, and only when every parent precedes it.
if not contains(nodes, node) and all_parents_added then
insert(nodes, node)
end
-- Children are re-queued on every visit, so a child skipped now (parents
-- incomplete) is revisited after its last parent has been appended.
for _, child in ipairs(node.children) do insert(q, child) end
end
end
if dummy_mode then remove(nodes, indexof(nodes, node_in)) end
return nodes
end
local function traverse_all(nodes_in, nodes_out, nodes)
  -- Order every node between any of nodes_in and any of nodes_out (parents
  -- first). A virtual source whose children are all inputs and a virtual
  -- sink whose parents are all outputs are built, then traverse() runs in
  -- dummy mode so the source is dropped from the result.
  local source = {children = {}, parents = {}}
  local sink = {children = {}, parents = {}}
  for _, node in ipairs(nodes_in) do
    source.children[#source.children + 1] = node
  end
  for _, node in ipairs(nodes_out) do
    sink.parents[#sink.parents + 1] = node
  end
  return traverse(source, sink, nodes or {}, true)
end
-- classes
-- Weights: a parameter buffer. Its array part holds the values; owning layers
-- set `.shape`, and Weights:allocate() fills the array and sets `.size`.
-- NOTE(review): assumes Base:extend() returns a callable subclass
-- (Layer:_new_weights instantiates with Weights(init)) — confirm against Base.
local Weights = Base:extend()
-- Layer: base class of every graph node (parent/child links, weights, shape
-- inference, forward evaluation).
local Layer = Base:extend()
-- Model: the graph of layers between a set of input and output nodes.
local Model = Base:extend()
-- Concrete layer types.
local Input = Layer:extend()
local Merge = Layer:extend()
local Relu = Layer:extend()
local Gelu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
local Embed = Layer:extend()
function Weights:init(weight_init)
  -- weight_init(t, fan_in, fan_out) is called by allocate() to fill the
  -- zeroed buffer; its return value is what allocate() returns.
  self.weight_init = weight_init
end
function Weights:allocate(fan_in, fan_out)
  -- Size the buffer from self.shape, zero it in place, then run the
  -- initializer on it.
  self.size = prod(self.shape)
  local function initialize(buffer)
    return self.weight_init(buffer, fan_in, fan_out)
  end
  return allocate(self.size, self, initialize)
end
-- Per-name instance counters, used to give every layer a unique display name.
local counter = {}
function Layer:init(name)
  -- `name` is the layer type; the display name gets a running index,
  -- e.g. "Dense[2]".
  assert(type(name) == "string")
  local n = (counter[name] or 0) + 1
  counter[name] = n
  self.name = name .. "[" .. tostring(n) .. "]"
  self.parents = {}
  self.children = {}
  self.weights = {}
end
function Layer:make_shape(parent)
  -- Default shape inference: inherit the parent's output shape as input and,
  -- unless already set, pass it through unchanged as the output shape.
  if self.shape_in == nil then
    self.shape_in = parent.shape_out
  end
  if self.shape_out == nil then
    self.shape_out = self.shape_in
  end
end
function Layer:feed(child)
  -- Connect self -> child (child infers its shape from self) and return the
  -- child so calls can be chained.
  assert(self.shape_out ~= nil, "missing output shape: "..self.name)
  child:make_shape(self)
  self.children[#self.children + 1] = child
  child.parents[#child.parents + 1] = self
  return child
end
function Layer:forward()
  -- Subclasses must implement forward(X).
  error("Unimplemented.", 1)
end
function Layer:forward_deterministic(...)
  -- Default: deterministic evaluation is the same as forward().
  local forward = self.forward
  return forward(self, ...)
end
function Layer:_new_weights(init)
  -- Register a new parameter buffer with the given initializer.
  local weights = Weights(init)
  self.weights[#self.weights + 1] = weights
  return weights
end
function Layer:get_size()
  -- Total number of parameters over all of this layer's weight buffers.
  local total = 0
  for _, w in ipairs(self.weights) do
    total = total + prod(w.shape)
  end
  return total
end
function Layer:init_weights()
  -- (Re)allocate and initialise every weight buffer, then drop caches.
  local fan_in = prod(self.shape_in)
  local fan_out = prod(self.shape_out)
  for _, w in ipairs(self.weights) do
    -- Clear the old contents first so a re-allocation never leaves stale
    -- trailing values (zeros() only overwrites slots 1..size). FIXME: HACK
    local j = 1
    while w[j] ~= nil do
      w[j] = nil
      j = j + 1
    end
    w:allocate(fan_in, fan_out)
  end
  self:reset_cache()
end
function Layer:reset_cache(bs)
  -- The base layer owns no buffers; it only records the batch size.
  self.bs = bs
end
function Layer:_propagate(edges, deterministic)
  -- Evaluate from the list of incoming values. The default accepts exactly
  -- one edge; override this if you need multiple parents.
  local count = #edges
  assert(count == 1, ("%s edges for node %s (expected 1)"):format(count, self.name))
  local X = edges[1]
  if not deterministic then
    return self:forward(X)
  end
  return self:forward_deterministic(X)
end
function Layer:propagate(values, deterministic)
  -- Gather the already-computed outputs of this node's parents (in parent
  -- order, skipping missing ones) and evaluate this node on them.
  local edges = {}
  for _, parent in ipairs(self.parents) do
    local X = values[parent]
    if X ~= nil then
      edges[#edges + 1] = X
    end
  end
  assert(#edges > 0, ("%s edges for node %s (expected >0)"):format(#edges, self.name))
  return self:_propagate(edges, deterministic)
end
function Input:init(shape)
  -- Graph entry point with a fixed per-sample `shape`.
  Layer.init(self, "Input")
  assert(type(shape) == 'table')
  self.shape_in, self.shape_out = shape, shape
end
function Input:forward(X)
  -- Validate the batch against the declared shape; pass it through as-is.
  checkshape(X, self.shape_in)
  return X
end
function Merge:init()
  Layer.init(self, "Merge")
  self.size = 0
  -- NOTE: shape_in is used as a count of connected parents, not as a shape.
  self.shape_in = 0
end
function Merge:make_shape(parent)
  -- Called once per parent by Layer:feed(). Each parent widens the per-sample
  -- output by the size of its own output.
  self.size = self.size + prod(parent.shape_out)
  self.shape_in = self.shape_in + 1 -- TODO: more robust.
  self.shape_out = {self.size}
end
function Merge:reset_cache(bs)
  self.bs = bs
  self.cache = cache(bs, self.shape_out)
end
function Merge:_propagate(edges, deterministic)
  -- Concatenate the parents' outputs sample by sample: row b of the result is
  -- row b of edges[1], then row b of edges[2], and so on. (Copying each batch
  -- wholesale, one after another, is only correct for a batch size of 1.)
  assert(#edges == self.shape_in)
  local bs = edges[1].shape[1]
  if bs ~= self.bs then self:reset_cache(bs) end
  local Y = self.cache
  -- Per-sample width of each incoming edge (product of its non-batch axes).
  local widths = {}
  for e, X in ipairs(edges) do
    assert(X.shape[1] == bs, "merged inputs have different batch sizes")
    local w = 1
    for d = 2, #X.shape do w = w * X.shape[d] end
    widths[e] = w
  end
  local yi = 1
  for b = 0, bs - 1 do
    for e, X in ipairs(edges) do
      local w = widths[e]
      local base = b * w
      for k = 1, w do
        Y[yi] = X[base + k]
        yi = yi + 1
      end
    end
  end
  checkshape(Y, self.shape_out)
  return Y
end
function Relu:init()
  Layer.init(self, "Relu")
end
function Relu:reset_cache(bs)
  -- Buffers are sized for `bs` samples; forward() rebuilds them whenever the
  -- batch size changes.
  self.bs = bs
  self.cache = cache(bs, self.shape_out)
  self.dcache = cache(bs, self.shape_in)
end
function Relu:forward(X)
  -- Element-wise rectifier into the reusable output buffer: values >= 0 pass
  -- through, everything else (negatives and NaN) becomes 0.
  local bs = checkshape(X, self.shape_in)
  if self.bs ~= bs then self:reset_cache(bs) end
  local Y = self.cache
  for i = 1, #X do
    local x = X[i]
    if x >= 0 then
      Y[i] = x
    else
      Y[i] = 0
    end
  end
  checkshape(Y, self.shape_out)
  return Y
end
function Gelu:init()
  Layer.init(self, "Gelu")
end
function Gelu:reset_cache(bs)
  -- Output buffer plus the intermediates a = 1.704*x and sigmoid(a), all
  -- sized for `bs` samples.
  self.bs = bs
  self.cache = cache(bs, self.shape_out)
  self.cache_a = cache(bs, self.shape_out)
  self.cache_sig = cache(bs, self.shape_out)
  self.dcache = cache(bs, self.shape_in)
end
function Gelu:forward(X)
  -- Approximate GELU via the sigmoid form: y = x * sigmoid(1.704 * x).
  -- NOTE(review): the sigmoid approximation usually cited for GELU uses
  -- 1.702; 1.704 may be a typo. It is kept here so outputs are unchanged.
  local bs = checkshape(X, self.shape_in)
  if self.bs ~= bs then self:reset_cache(bs) end
  local Y, scaled, gate = self.cache, self.cache_a, self.cache_sig
  for i = 1, #X do
    local x = X[i]
    local ax = 1.704 * x
    local s = 1 / (1 + exp(-ax))
    scaled[i] = ax
    gate[i] = s
    Y[i] = x * s
  end
  checkshape(Y, self.shape_out)
  return Y
end
function Dense:init(dim)
  Layer.init(self, "Dense")
  assert(type(dim) == "number")
  self.dim = dim
  self.shape_out = {dim}
  -- He-normal weight matrix and a zero-initialised bias row.
  self.coeffs = self:_new_weights(init_he_normal)
  self.biases = self:_new_weights(init_zeros)
end
function Dense:make_shape(parent)
  -- Inherit the input shape from the parent. The weight matrix maps the last
  -- input axis onto `dim` outputs.
  self.shape_in = parent.shape_out
  self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
  self.biases.shape = {1, self.dim}
end
function Dense:reset_cache(bs)
  self.bs = bs
  self.cache = cache(bs, self.shape_out)
  self.cache_x = cache(bs, self.shape_in)
  self.dcache = cache(bs, self.shape_in)
end
function Dense:forward(X)
  -- Y = X . coeffs + biases. For a 1-D input shape, X is {bs, in} and Y is
  -- {bs, dim}. NOTE(review): multi-dimensional input shapes would contract
  -- axis 2, not the last axis — confirm they are unused.
  local bs = checkshape(X, self.shape_in)
  if self.bs ~= bs then self:reset_cache(bs) end
  local Y = self.cache
  dot(X, self.coeffs, 2, 1, Y)
  -- Add the bias row to every sample, not only to the first one.
  local dim = self.dim
  local biases = self.biases
  for row = 0, bs - 1 do
    local offset = row * dim
    for i = 1, dim do
      Y[offset + i] = Y[offset + i] + biases[i]
    end
  end
  checkshape(Y, self.shape_out)
  return Y
end
function Softmax:init()
  Layer.init(self, "Softmax")
end
function Softmax:reset_cache(bs)
  self.bs = bs
  self.cache = cache(bs, self.shape_out)
end
function Softmax:forward(X)
  -- Row-wise softmax over the last axis of X, written into the reusable
  -- output buffer. Each row gets its own max-shift (alpha) and normaliser
  -- (den); sharing them across the batch mis-normalises every row after the
  -- first.
  local bs = checkshape(X, self.shape_in)
  if self.bs ~= bs then self:reset_cache(bs) end
  local Y = self.cache
  local l = X.shape[#X.shape]
  if l ~= nil and l > 0 then
    for j = 0, #X - 1, l do
      -- Shift by the row maximum (starting from -inf, not 0) so that the
      -- largest term is exp(0) = 1 and den can never underflow to 0.
      local alpha = -math.huge
      for i = j + 1, j + l do alpha = max(alpha, X[i]) end
      local den = 0
      for i = j + 1, j + l do
        local e = exp(X[i] - alpha)
        Y[i] = e
        den = den + e
      end
      for i = j + 1, j + l do Y[i] = Y[i] / den end
    end
  end
  checkshape(Y, self.shape_out)
  return Y
end
function Embed:init(vocab, dim)
  -- Token embedding: `vocab` rows of `dim` values each.
  Layer.init(self, "Embed")
  assert(type(vocab) == "number")
  assert(type(dim) == "number")
  self.vocab = vocab
  self.dim = dim
  self.lut = self:_new_weights(init_uniform)
  self.lut.shape = {vocab, dim}
end
function Embed:make_shape(parent)
  -- Each input position expands into `dim` output values.
  local shape = parent.shape_out
  self.shape_in = shape
  self.shape_out = {shape[1] * self.dim}
end
function Embed:reset_cache(bs)
  self.bs = bs
  self.cache = cache(bs, self.shape_out)
  self.cache_x = cache(bs, self.shape_in)
end
function Embed:forward(X)
  -- Replace every element of X (used as a 0-based row index into the lookup
  -- table) by that row's `dim` values, in order.
  local bs = checkshape(X, self.shape_in)
  if self.bs ~= bs then self:reset_cache(bs) end
  local Y = self.cache
  local lut = self.lut
  local dim = self.dim
  for t, token in ipairs(X) do
    local dst = (t - 1) * dim
    local src = token * dim
    for j = 1, dim do
      Y[dst + j] = lut[src + j]
    end
  end
  checkshape(Y, self.shape_out)
  return Y
end
function Model:init(nodes_in, nodes_out)
  -- nodes_in / nodes_out are non-empty lists of input and output layers.
  assert(#nodes_in > 0, #nodes_in)
  assert(#nodes_out > 0, #nodes_out)
  self.nodes_in, self.nodes_out = nodes_in, nodes_out
  -- Every node on a path from an input to an output, parents first.
  self.nodes = traverse_all(nodes_in, nodes_out)
end
function Model:reset()
  -- Re-initialise every node's weights (printing each node's name and size)
  -- and recompute the total parameter count into self.n_param.
  self.n_param = 0
  for _, node in ipairs(self.nodes) do
    local size_before = node:get_size()
    print(node.name, size_before)
    node:init_weights()
    self.n_param = self.n_param + node:get_size()
  end
end
function Model:forward(inputs)
  -- Evaluate the graph once, in traversal order. `inputs` maps each input
  -- node to its batch. Returns a table mapping each output node to its value.
  local values = {}
  local outputs = {}
  for _, node in ipairs(self.nodes) do
    local Y
    if not contains(self.nodes_in, node) then
      Y = node:propagate(values)
    else
      local X = inputs[node]
      assert(X ~= nil, ("missing input for node %s"):format(node.name))
      assert(X.shape, ("missing shape for node %s"):format(node.name))
      Y = node:_propagate({X})
    end
    values[node] = Y
    if contains(self.nodes_out, node) then
      outputs[node] = Y
    end
  end
  return outputs
end
function Model:cleargrad()
  -- Not implemented yet: always raises.
  error("TODO", 1) -- TODO
end
function Model:print()
  -- Print the graph's edges in Graphviz DOT syntax. Layer names such as
  -- "Dense[1]" are not valid bare DOT identifiers ('[' starts an attribute
  -- list), so they are emitted as quoted IDs.
  print("digraph G {")
  for _, parent in ipairs(self.nodes) do
    -- No emptiness check needed: an empty children list yields no edges
    -- (and `#t` is always truthy in Lua anyway).
    for _, child in ipairs(parent.children) do
      print(('\t"%s" -> "%s";'):format(parent.name, child.name))
    end
  end
  print('}')
end
function Model:collect()
  -- Flatten every node's weight arrays, in graph order, into one array of
  -- length self.n_param. distribute() performs the inverse.
  -- if Lua had slices, we wouldn't need this. future library idea?
  assert(self.n_param >= 0, self.n_param)
  local flat = zeros(self.n_param)
  local offset = 0
  for _, node in ipairs(self.nodes) do
    for _, weights in ipairs(node.weights) do
      for j, value in ipairs(weights) do
        flat[offset + j] = value
      end
      offset = offset + #weights
    end
  end
  return flat
end
function Model:distribute(W)
  -- Inverse of collect(): copy consecutive runs of the flat array W back
  -- into each node's weight arrays, in graph order.
  local offset = 0
  for _, node in ipairs(self.nodes) do
    for _, weights in ipairs(node.weights) do
      for j in ipairs(weights) do
        weights[j] = W[offset + j]
      end
      offset = offset + #weights
    end
  end
end
function Model:default_filename()
  -- File name keyed on the parameter count, zero-padded to 7 digits.
  return string.format('network%07i.txt', self.n_param)
end
function Model:save(fn)
  -- Write all weights (see collect()) to `fn` (default: default_filename()),
  -- one number per line.
  fn = fn or self:default_filename()
  -- Flatten first, so a failure here does not truncate an existing file.
  local W = self:collect()
  local f = open(fn, 'w')
  if f == nil then error("Failed to save network to file "..fn) end
  for _, v in ipairs(W) do
    -- '%.17g' round-trips an IEEE double exactly. Writing the number directly
    -- uses Lua's default '%.14g' formatting, which loses precision.
    f:write(('%.17g'):format(v), '\n')
  end
  f:close()
end
function Model:load(fn)
  -- Read weights written by save() (one number per line) from `fn`
  -- (default: default_filename()) and distribute them over the graph.
  -- A short file leaves the remaining weights at 0. Extra lines are read but
  -- not used by distribute().
  fn = fn or self:default_filename()
  local f = open(fn, 'r')
  if f == nil then
    error("Failed to load network from file "..fn)
  end
  local W = zeros(self.n_param)
  local i = 0
  for line in f:lines() do
    i = i + 1
    local n = tonumber(line)
    if n == nil then
      -- Close the handle before raising so it is not leaked until GC.
      f:close()
      error("Failed reading line "..tostring(i).." of file "..fn)
    end
    W[i] = n
  end
  f:close()
  self:distribute(W)
end
-- Public API of the module. All weight initializers, including
-- init_uniform (used by Embed), are exported.
return {
copy = copy,
indexof = indexof,
contains = contains,
prod = prod,
uniform = uniform,
normal = normal,
zeros = zeros,
arange = arange,
allocate = allocate,
init_zeros = init_zeros,
init_uniform = init_uniform,
init_he_uniform = init_he_uniform,
init_he_normal = init_he_normal,
reshape = reshape,
pp = pp,
ppi = ppi,
dot = dot,
traverse = traverse,
traverse_all = traverse_all,
Weights = Weights,
Layer = Layer,
Model = Model,
Input = Input,
Merge = Merge,
Relu = Relu,
Gelu = Gelu,
Dense = Dense,
Softmax = Softmax,
Embed = Embed,
}