2017-09-07 12:04:36 -07:00
|
|
|
local cos = math.cos
|
|
|
|
local exp = math.exp
|
|
|
|
local insert = table.insert
|
2017-06-28 02:33:18 -07:00
|
|
|
local ipairs = ipairs
|
|
|
|
local log = math.log
|
|
|
|
local max = math.max
|
2017-09-07 12:04:36 -07:00
|
|
|
local min = math.min
|
2017-06-28 21:51:02 -07:00
|
|
|
local open = io.open
|
2017-09-07 12:04:36 -07:00
|
|
|
local pairs = pairs
|
|
|
|
local pi = math.pi
|
|
|
|
local print = print
|
|
|
|
local remove = table.remove
|
|
|
|
local sin = math.sin
|
|
|
|
local sqrt = math.sqrt
|
|
|
|
local tostring = tostring
|
|
|
|
local uniform = math.random
|
2017-09-07 12:06:25 -07:00
|
|
|
local unpack = table.unpack or unpack
|
2017-06-28 02:33:18 -07:00
|
|
|
|
|
|
|
local Base = require("Base")
|
|
|
|
|
2017-09-07 12:06:25 -07:00
|
|
|
-- general utilities
|
|
|
|
|
2017-09-07 12:04:36 -07:00
|
|
|
-- Return a shallow copy of table `t`: a new table holding the same
-- key/value pairs. Nested tables are shared with `t`, not duplicated.
local function copy(t)
    local dup = {}
    for key, value in pairs(t) do
        dup[key] = value
    end
    return dup
end
|
|
|
|
|
|
|
|
-- Return a key of `t` (the first one met in `pairs` order) whose value
-- equals `a`, or nil when no value matches. Errors if `t` is not a table.
local function indexof(t, a)
    assert(type(t) == "table")
    local found = nil
    for key, value in pairs(t) do
        if value == a then
            found = key
            break
        end
    end
    return found
end
|
|
|
|
|
|
|
|
-- True when some value stored in `t` equals `a`.
local function contains(t, a)
    local key = indexof(t, a)
    return key ~= nil
end
|
|
|
|
|
2017-09-07 12:06:25 -07:00
|
|
|
-- math utilities
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- Multiply numbers together.
--   prod(a, b, c)    -> a * b * c
--   prod({a, b, c})  -> a * b * c   (a table first argument is flattened)
--   prod({a, b}, c)  -> a * b * c   (remaining arguments are still applied)
--   prod({})         -> 1           (empty product)
-- A single number is returned unchanged; prod() with no arguments returns nil.
local function prod(x, ...)
    local ret
    if type(x) == "table" then
        -- iterate rather than unpack(): avoids unpack's result-count limit,
        -- gives the multiplicative identity for an empty table, and keeps
        -- the extra varargs, which the unpack-based recursion dropped.
        ret = 1
        for i = 1, #x do
            ret = ret * prod(x[i])
        end
    else
        ret = x
    end
    for i = 1, select("#", ...) do
        ret = ret * select(i, ...)
    end
    return ret
end
|
|
|
|
|
2017-06-28 17:14:56 -07:00
|
|
|
-- Draw one sample from the standard normal distribution N(0, 1) using the
-- Box-Muller transform on two uniform draws.
--
-- uniform() (math.random) returns values in [0, 1), so 1 - uniform() lies
-- in (0, 1]. Its log is finite and <= 0, so the sqrt argument can never be
-- negative. Adding an epsilon inside the log instead lets log() go slightly
-- positive for draws close to 1, and sqrt() then returns NaN.
local function normal()
    local u1 = 1 - uniform()
    local u2 = uniform()
    return sqrt(-2 * log(u1)) * cos(2 * pi * u2)
end
|
|
|
|
|
|
|
|
-- Set entries 1..n of `out` to 0 and return it. A fresh table is used when
-- `out` is not given. Entries above n are left untouched.
local function zeros(n, out)
    local t = out
    if not t then t = {} end
    for idx = 1, n do t[idx] = 0 end
    return t
end
|
|
|
|
|
2017-09-07 12:06:25 -07:00
|
|
|
-- Zero-fill `t` slots (t is a count) of `out`, or of a fresh table when
-- `out` is nil. If `init` is given, return init(filled_table); otherwise
-- return the filled table itself.
local function allocate(t, out, init)
    local filled = zeros(t, out or {})
    if init == nil then
        return filled
    end
    return init(filled)
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- Weight initializer: overwrite every element of the sequence `t` with 0
-- and return `t`. fan_in/fan_out are unused; they are accepted only to
-- match the initializer signature.
local function init_zeros(t, fan_in, fan_out)
    local len = #t
    for i = 1, len do
        t[i] = 0
    end
    return t
end
|
|
|
|
|
|
|
|
-- He-uniform weight initializer: fill `t` with values drawn uniformly from
-- [-limit, limit), where limit = sqrt(6 / fan_in). fan_out is unused.
local function init_he_uniform(t, fan_in, fan_out)
    local limit = sqrt(6 / fan_in)
    for i = 1, #t do
        local r = uniform()
        t[i] = (r * 2 - 1) * limit
    end
    return t
end
|
|
|
|
|
|
|
|
-- He-normal weight initializer: fill `t` with standard-normal samples
-- (from normal()) scaled by sqrt(2 / fan_in). fan_out is unused.
local function init_he_normal(t, fan_in, fan_out)
    local scale = sqrt(2 / fan_in)
    local count = #t
    for i = 1, count do
        t[i] = normal() * scale
    end
    return t
end
|
|
|
|
|
2017-09-07 12:06:25 -07:00
|
|
|
-- nodal
|
2017-06-28 21:51:02 -07:00
|
|
|
|
2017-09-07 12:04:36 -07:00
|
|
|
-- Return the nodes lying on some path from `node_in` to `node_out`, ordered
-- so that each node appears after all of its parents. Nodes are tables with
-- `parents` and `children` arrays.
--
-- Phase 1 walks upward from node_out and marks it and all of its ancestors.
-- Phase 2 walks downward from node_in. It emits a marked node once all of
-- that node's parents have been emitted. A node whose parents are not all
-- ready yet is revisited later through another parent.
--
-- `nodes` is accepted for backward compatibility but is not used; a fresh
-- table is always returned.
-- With `dummy_mode`, node_in is treated as a synthetic root (as built by
-- traverse_all): it is always considered marked, and it is removed from the
-- result.
--
-- The emitted order decides the weight layout used by Model:collect and
-- Model:distribute, and therefore by saved files. Phase 2 keeps the exact
-- visiting order of the earlier implementation. The changes here only
-- remove redundant work: each ancestor is expanded once in phase 1, queue
-- pops are O(1), and membership tests are O(1).
local function traverse(node_in, node_out, nodes, dummy_mode)
    -- phase 1: mark node_out and everything upstream of it. Expanding each
    -- node only once keeps stacked diamonds from blowing up exponentially.
    local seen_up = {}
    local q, head, tail = {node_out}, 1, (node_out ~= nil) and 1 or 0
    while head <= tail do
        local node = q[head]
        q[head] = nil
        head = head + 1
        if not seen_up[node] then
            seen_up[node] = true
            for _, parent in ipairs(node.parents) do
                tail = tail + 1
                q[tail] = parent
            end
        end
    end

    if dummy_mode then seen_up[node_in] = true end

    -- phase 2: emit marked descendants of node_in in dependency order.
    local result, added = {}, {}
    q, head, tail = {node_in}, 1, (node_in ~= nil) and 1 or 0
    while head <= tail do
        local node = q[head]
        q[head] = nil
        head = head + 1
        if seen_up[node] then
            local all_parents_added = true
            for _, parent in ipairs(node.parents) do
                if not added[parent] then
                    all_parents_added = false
                    break
                end
            end
            if all_parents_added and not added[node] then
                added[node] = true
                result[#result + 1] = node
            end
            -- children are queued on every visit, not only the first one,
            -- so a node reached before its parents were all ready gets
            -- another chance. This reproduces the original ordering.
            for _, child in ipairs(node.children) do
                tail = tail + 1
                q[tail] = child
            end
        end
    end

    if dummy_mode then
        -- drop the synthetic root. Search first: table.remove(t, nil) would
        -- silently remove the last element instead.
        for i = 1, #result do
            if result[i] == node_in then
                remove(result, i)
                break
            end
        end
    end

    return result
end
|
|
|
|
|
|
|
|
-- Collect every node lying between any of `nodes_in` and any of
-- `nodes_out`, in dependency order. A synthetic root whose children are the
-- inputs and a synthetic sink whose parents are the outputs are built. Then
-- traverse() runs in dummy mode, which drops the root from the result
-- again.
local function traverse_all(nodes_in, nodes_out, nodes)
    local root = {children = {}, parents = {}}
    for i, node in ipairs(nodes_in) do root.children[i] = node end

    local sink = {children = {}, parents = {}}
    for i, node in ipairs(nodes_out) do sink.parents[i] = node end

    return traverse(root, sink, nodes or {}, true)
end
|
|
|
|
|
2017-09-07 12:06:25 -07:00
|
|
|
-- classes
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- Class tables. Base comes from require("Base"); Base:extend() is assumed
-- to return a subclass table (TODO confirm against the Base module).
-- Creation order is kept as-is in case extend() has side effects.
local Weights = Base:extend() -- a parameter tensor stored in its own array part
local Layer = Base:extend()   -- graph node: parents/children links + weights
local Model = Base:extend()   -- a set of layers between input and output nodes

-- concrete layer types
local Input = Layer:extend()
local Relu = Layer:extend()
local Gelu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
|
|
|
|
|
|
|
|
-- Constructor: store the initializer that Weights:allocate later calls as
-- weight_init(t, fan_in, fan_out).
function Weights:init(weight_init)
    local initializer = weight_init
    self.weight_init = initializer
end
|
|
|
|
|
|
|
|
-- Size this object from `self.shape` (a number, or a table of dimensions
-- multiplied by prod) and record it in self.size. Then zero-fill that many
-- array slots of the object itself and pass it to the stored initializer
-- with the given fan-in and fan-out. Returns the initializer's result.
function Weights:allocate(fan_in, fan_out)
    local size = prod(self.shape)
    self.size = size
    local initialize = function(t)
        return self.weight_init(t, fan_in, fan_out)
    end
    return allocate(size, self, initialize)
end
|
|
|
|
|
|
|
|
-- Number of layers created so far for each layer name, used to build
-- unique labels such as "Dense[3]".
local counter = {}

-- Constructor shared by all layers. It assigns a unique name and empty
-- parent/child/weight lists. size_in and size_out are left unset here; a
-- subclass constructor or make_shape fills them in.
function Layer:init(name)
    assert(type(name) == "string")
    local count = (counter[name] or 0) + 1
    counter[name] = count
    self.name = name .. "[" .. tostring(count) .. "]"
    self.parents = {}
    self.children = {}
    self.weights = {}
end
|
|
|
|
|
|
|
|
-- Default shape inference. If no input size is set, take the parent's
-- output size. If no output size is set, reuse the input size, which makes
-- the layer shape-preserving.
function Layer:make_shape(parent)
    local size_in = self.size_in
    if size_in == nil then
        size_in = parent.size_out
        self.size_in = size_in
    end
    if self.size_out == nil then
        self.size_out = size_in
    end
end
|
|
|
|
|
|
|
|
-- Connect `child` downstream of this layer. The child infers its shape from
-- this layer, and the edge is recorded on both sides. Returns `child`, so
-- calls can be chained: a:feed(b):feed(c).
function Layer:feed(child)
    assert(self.size_out ~= nil)
    child:make_shape(self)
    local kids, ups = self.children, child.parents
    kids[#kids + 1] = child
    ups[#ups + 1] = self
    return child
end
|
|
|
|
|
|
|
|
-- Abstract forward pass. Concrete layers override this as forward(X) and
-- return their output array. Reaching this means a subclass is missing the
-- override, so the message names the offending layer, and the error is
-- reported at the caller's position (level 2) rather than at this line.
function Layer:forward()
    local who = tostring(self and self.name)
    error(("Unimplemented: %s:forward()"):format(who), 2)
end
|
|
|
|
|
|
|
|
-- Forward pass used when `deterministic` is requested (see _propagate). By
-- default it is identical to forward(); a subclass may override it.
function Layer:forward_deterministic(...)
    local fwd = self.forward
    return fwd(self, ...)
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- Abstract backward pass. Concrete layers override this as backward(dY)
-- and return the gradient w.r.t. their input. Several layers in this module
-- do not override it, so the message names the layer that was asked for a
-- gradient, and the error is reported at the caller's position.
function Layer:backward()
    local who = tostring(self and self.name)
    error(("Unimplemented: %s:backward()"):format(who), 2)
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- Create a Weights object using initializer `init` and register it on this
-- layer, so init_weights, get_size and Model:collect include it. Returns
-- the new object.
function Layer:_new_weights(init)
    local weights = Weights(init)
    local list = self.weights
    list[#list + 1] = weights
    return weights
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- Total parameter count of this layer, summed over its Weights objects.
-- Each object's `size` is set by Weights:allocate, so this is only valid
-- after init_weights().
function Layer:get_size()
    local total = 0
    for _, w in ipairs(self.weights) do
        total = total + prod(w.size)
    end
    return total
end
|
|
|
|
|
|
|
|
-- (Re)allocate and initialize each of this layer's Weights objects, using
-- the layer's input and output sizes as fan-in and fan-out.
function Layer:init_weights()
    for _, w in ipairs(self.weights) do
        -- FIXME: HACK. The values live in the Weights object's own array
        -- part. Clear it first; presumably this is so a previous, larger
        -- allocation cannot leave stale trailing values behind.
        for j in ipairs(w) do w[j] = nil end
        w:allocate(self.size_in, self.size_out)
    end
end
|
|
|
|
|
|
|
|
-- Run this layer on its already-computed inputs (`edges`). The base version
-- accepts exactly one input; layers with several parents must override it.
-- A truthy `deterministic` selects forward_deterministic over forward.
function Layer:_propagate(edges, deterministic)
    assert(#edges == 1, #edges) -- override this if you need multiple parents.
    local X = edges[1]
    if not deterministic then
        return self:forward(X)
    end
    return self:forward_deterministic(X)
end
|
|
|
|
|
|
|
|
-- Collect this layer's inputs from `values`, a map from node to its output.
-- Parents with no entry are skipped, and at least one parent must have an
-- output. The collected inputs are passed to _propagate, and its first
-- result is returned.
function Layer:propagate(values, deterministic)
    local edges = {}
    for _, parent in ipairs(self.parents) do
        local X = values[parent]
        if X ~= nil then
            edges[#edges + 1] = X
        end
    end
    assert(#edges > 0, #edges)
    return (self:_propagate(edges, deterministic))
end
|
|
|
|
|
|
|
|
-- Graph entry point carrying `size` values. Its input size and output size
-- are both `size`.
function Input:init(size)
    Layer.init(self, "Input")
    assert(type(size) == 'number')
    self.size_in = size
    self.size_out = self.size_in
end
|
|
|
|
|
|
|
|
-- Identity: check the length of X and return X itself (no copy is made).
function Input:forward(X)
    local n = #X
    assert(n == self.size_in)
    return X
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- Check the length of dY and return a fresh all-zero array of that length.
function Input:backward(dY)
    local n = #dY
    assert(n == self.size_out)
    return zeros(n)
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- Rectified linear unit. It takes its sizes from Layer:make_shape.
function Relu:init()
    local name = "Relu"
    Layer.init(self, name)
end
|
|
|
|
|
|
|
|
-- Y[i] = X[i] when X[i] >= 0, else 0. The output array is cached on the
-- layer (self.cache) and overwritten on every call; Relu:backward reads it.
function Relu:forward(X)
    assert(#X == self.size_in)
    if not self.cache then self.cache = zeros(self.size_out) end
    local Y = self.cache

    for i = 1, #X do
        local x = X[i]
        if x >= 0 then Y[i] = x else Y[i] = 0 end
    end

    assert(#Y == self.size_out)
    return Y
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- ReLU gradient: dX[i] = dY[i] where the forward output was positive, and
-- 0 elsewhere. It reads the output cached by the last Relu:forward call.
-- The result is a cached array (self.dcache), overwritten on every call.
--
-- The mask must test Y[i] > 0: Y = max(X, 0) is never negative, so a
-- `>= 0` test would pass every gradient through unchanged.
function Relu:backward(dY)
    assert(#dY == self.size_out)
    local Y = self.cache
    assert(Y ~= nil, "Relu:backward called before Relu:forward")
    self.dcache = self.dcache or zeros(self.size_in)
    local dX = self.dcache

    for i = 1, #dY do
        if Y[i] > 0 then dX[i] = dY[i] else dX[i] = 0 end
    end

    assert(#dX == self.size_in)
    return dX
end
|
|
|
|
|
2017-06-29 02:50:33 -07:00
|
|
|
-- Gaussian error linear unit (sigmoid approximation). It takes its sizes
-- from Layer:make_shape.
function Gelu:init()
    local name = "Gelu"
    Layer.init(self, name)
end
|
|
|
|
|
|
|
|
-- Approximate GELU: Y[i] = X[i] * sigmoid(1.704 * X[i]), computed as
-- X[i] / (1 + exp(-1.704 * X[i])). The output array is cached and
-- overwritten on every call.
-- NOTE(review): the sigmoid approximation is usually quoted with 1.702;
-- confirm that 1.704 is intentional before changing it (it alters outputs).
function Gelu:forward(X)
    assert(#X == self.size_in)
    local Y = self.cache
    if not Y then
        Y = zeros(self.size_out)
        self.cache = Y
    end

    local n = #X
    for i = 1, n do
        local x = X[i]
        Y[i] = x / (1 + exp(-1.704 * x))
    end

    assert(#Y == self.size_out)
    return Y
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- Fully-connected layer with `dim` outputs. It creates a coefficient matrix
-- (He-normal initializer) and a bias vector (zero initializer); their
-- shapes are set in make_shape once the input size is known. The creation
-- order (coeffs, then biases) decides where they sit in Model:collect's
-- flat layout, so keep it.
function Dense:init(dim)
    Layer.init(self, "Dense")
    assert(type(dim) == "number")
    self.dim = dim
    self.size_out = dim
    self.coeffs = self:_new_weights(init_he_normal)
    self.biases = self:_new_weights(init_zeros)
end
|
|
|
|
|
|
|
|
-- Take the input size from `parent` and shape the weights: coeffs is
-- {size_in, dim} (size_in * dim values) and biases holds dim values.
function Dense:make_shape(parent)
    local n_in = parent.size_out
    self.size_in = n_in
    self.coeffs.shape = {n_in, self.dim}
    self.biases.shape = self.dim
end
|
|
|
|
|
|
|
|
-- Y = W X + b. W is stored flat: the weights for output i occupy
-- coeffs[(i - 1) * #X + 1 .. i * #X]. The output array is cached and
-- overwritten on every call.
function Dense:forward(X)
    local n_in = #X
    assert(n_in == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache
    local coeffs, biases = self.coeffs, self.biases

    for i = 1, self.dim do
        local offset = (i - 1) * n_in
        local acc = 0
        for j = 1, n_in do
            acc = acc + X[j] * coeffs[offset + j]
        end
        Y[i] = acc + biases[i]
    end

    assert(#Y == self.size_out)
    return Y
end
|
|
|
|
|
|
|
|
-- Softmax over the whole input vector. It takes its sizes from
-- Layer:make_shape.
function Softmax:init()
    local name = "Softmax"
    Layer.init(self, name)
end
|
|
|
|
|
|
|
|
-- Numerically stable softmax: Y[i] = exp(X[i] - m) / sum_j exp(X[j] - m),
-- where m = max(X). The output array is cached and overwritten on every
-- call.
--
-- m starts from X[1], not 0. With 0 as the start value, inputs that are all
-- strongly negative (below about -745) underflow every exp() to 0, and the
-- result is 0/0 = NaN. The numerators are stored directly in Y, so no
-- temporary table is allocated per call.
function Softmax:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache
    local n = #X

    local alpha = X[1]
    for i = 2, n do alpha = max(alpha, X[i]) end

    local den = 0
    for i = 1, n do
        local e = exp(X[i] - alpha)
        Y[i] = e
        den = den + e
    end
    for i = 1, n do Y[i] = Y[i] / den end

    assert(#Y == self.size_out)
    return Y
end
|
|
|
|
|
|
|
|
-- Build a model from non-empty arrays of input layers and output layers.
-- A single layer must be wrapped in an array. The layers used (every layer
-- on a path from an input to an output) are found once, here, in
-- dependency order and stored in self.nodes.
function Model:init(nodes_in, nodes_out)
    assert(#nodes_in > 0, #nodes_in)
    assert(#nodes_out > 0, #nodes_out)

    self.nodes_in = nodes_in
    self.nodes_out = nodes_out

    self.nodes = traverse_all(nodes_in, nodes_out)
end
|
|
|
|
|
|
|
|
-- (Re)initialize every layer's weights and recount the model's total
-- parameter count into self.n_param.
function Model:reset()
    self.n_param = 0
    for _, layer in ipairs(self.nodes) do
        layer:init_weights()
        self.n_param = self.n_param + layer:get_size()
    end
end
|
|
|
|
|
2017-09-07 12:09:44 -07:00
|
|
|
-- Run the model once. `inputs` maps each input layer to its input array,
-- and every input layer must have one. Returns a table mapping each output
-- layer to its output array. With the layers in this module, that array is
-- the layer's reused cache, so copy it if it must survive the next call.
function Model:forward(inputs)
    local values, outputs = {}, {}
    for _, node in ipairs(self.nodes) do
        local Y
        if contains(self.nodes_in, node) then
            local X = inputs[node]
            assert(X ~= nil, ("missing input for node %s"):format(node.name))
            Y = node:_propagate({X})
        else
            Y = node:propagate(values)
        end
        values[node] = Y
        if contains(self.nodes_out, node) then
            outputs[node] = Y
        end
    end
    return outputs
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- Flatten all weights into one array. The order is by self.nodes, then by
-- each layer's weights list. The array starts as self.n_param zeros, so
-- reset() must have run first. Model:distribute is the inverse operation.
function Model:collect()
    assert(self.n_param >= 0, self.n_param)
    local W = zeros(self.n_param)
    local offset = 0
    for _, layer in ipairs(self.nodes) do
        for _, w in ipairs(layer.weights) do
            for j, v in ipairs(w) do
                W[offset + j] = v
            end
            offset = offset + #w
        end
    end
    return W
end
|
|
|
|
|
|
|
|
-- Inverse of Model:collect(): copy the flat array W back into every
-- layer's Weights objects, walking them in the same order.
function Model:distribute(W)
    local offset = 0
    for _, layer in ipairs(self.nodes) do
        for _, w in ipairs(layer.weights) do
            for j in ipairs(w) do
                w[j] = W[offset + j]
            end
            offset = offset + #w
        end
    end
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- Default save/load path, keyed on the parameter count and zero-padded to
-- 7 digits, e.g. "network0001234.txt" for 1234 parameters.
function Model:default_filename()
    local pattern = 'network%07i.txt'
    return pattern:format(self.n_param)
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- Write every weight, in Model:collect order and one per line, to `fn`
-- (default: self:default_filename()).
--
-- Values are formatted with "%.17g", which round-trips any double exactly.
-- Passing a number straight to f:write uses Lua's default "%.14g" number
-- format, which loses precision between save() and load().
-- The weights are collected before the file is opened: opening with 'w'
-- truncates the file, so a failing collect() must not destroy an existing
-- save.
function Model:save(fn)
    fn = fn or self:default_filename()
    local W = self:collect()

    local f = open(fn, 'w')
    if f == nil then error("Failed to save network to file "..fn) end

    for _, v in ipairs(W) do
        local ok, err = f:write(("%.17g\n"):format(v))
        if not ok then
            f:close()
            error("Failed to save network to file "..fn..": "..tostring(err))
        end
    end
    f:close()
end
|
|
|
|
|
|
|
|
-- Read weights written by Model:save from `fn` (default:
-- self:default_filename()) and distribute them into the layers. reset()
-- must have run first so that self.n_param is known.
--
-- Raises an error if a line is not a number, or if the file does not hold
-- exactly self.n_param values. Otherwise a short file would silently leave
-- weights at 0 and a long file would be silently truncated. The file is
-- closed before any error is raised.
function Model:load(fn)
    fn = fn or self:default_filename()
    local f = open(fn, 'r')
    if f == nil then
        error("Failed to load network from file "..fn)
    end

    local W = zeros(self.n_param)
    local i = 0
    for line in f:lines() do
        i = i + 1
        local n = tonumber(line)
        if n == nil then
            f:close()
            error("Failed reading line "..tostring(i).." of file "..fn)
        end
        W[i] = n
    end
    f:close()

    if i ~= self.n_param then
        error(("Expected %s weights in file %s but found %s"):format(
            tostring(self.n_param), fn, tostring(i)))
    end

    self:distribute(W)
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- Public interface of this module.
return {
    -- general utilities
    copy = copy,
    indexof = indexof,
    contains = contains,

    -- math utilities
    prod = prod,
    normal = normal,
    zeros = zeros,
    allocate = allocate,

    -- weight initializers
    init_zeros = init_zeros,
    init_he_uniform = init_he_uniform,
    init_he_normal = init_he_normal,

    -- graph traversal
    traverse = traverse,
    traverse_all = traverse_all,

    -- core classes
    Weights = Weights,
    Layer = Layer,
    Model = Model,

    -- layer types
    Input = Input,
    Relu = Relu,
    Gelu = Gelu,
    Dense = Dense,
    Softmax = Softmax,
}
|