-- smbot/nn.lua
-- (379 lines, 8.9 KiB, Lua)
-- 2017-06-28 17:14:56 -07:00
local print = print
-- 2017-06-28 02:33:18 -07:00
local tostring = tostring
local ipairs = ipairs
local pairs = pairs
local uniform = math.random
local sqrt = math.sqrt
local log = math.log
local pi = math.pi
local exp = math.exp
local min = math.min
local max = math.max
local cos = math.cos
local sin = math.sin
local insert = table.insert
local remove = table.remove
local bor = bit.bor
local Base = require("Base")
-- Linear membership test over the values of a table.
-- t: table to search (array and hash entries are both examined; keys are ignored)
-- a: value to look for, compared with ==
-- returns true when some value equals `a`, false otherwise.
local function contains(t, a)
    assert(type(t) == "table")
    local found = false
    for _, value in pairs(t) do
        if value == a then
            found = true
            break
        end
    end
    return found
end
-- 2017-06-28 17:14:56 -07:00
-- Draws one sample from the standard normal distribution using the
-- Box-Muller transform on two draws from `uniform`.
-- The 1e-8 offset inside log() keeps the argument away from zero.
-- returns a finite number (never NaN).
local function normal() -- box muller
    -- When the first draw lies within roughly 5e-9 of 1, log() becomes
    -- slightly positive and the radicand drops just below zero; without the
    -- clamp sqrt() would then yield NaN and poison any weights built from it.
    local radicand = -2 * log(uniform() + 1e-8) + 1e-8
    local radius = sqrt(max(0, radicand))
    return radius * cos(2 * pi * uniform())
end
-- Writes the number 0 into slots 1..n of a table and returns that table.
-- n:   how many leading slots to zero
-- out: optional table to fill in place (slots past n are left untouched);
--      a new table is created when it is omitted.
local function zeros(n, out)
    if not out then
        out = {}
    end
    for slot = 1, n do
        out[slot] = 0
    end
    return out
end
-- Weight initializer that sets every element of the sequence `t` to 0.
-- fan_in and fan_out are accepted so the signature matches the other
-- initializers; they are not used.
-- returns `t` itself.
local function init_zeros(t, fan_in, fan_out)
    local count = #t
    for slot = 1, count do
        t[slot] = 0
    end
    return t
end
-- Weight initializer: fills the sequence `t` in place with values drawn
-- uniformly from [-s, s) where s = sqrt(6 / fan_in).
-- fan_out is accepted for signature parity with the other initializers
-- and is not used.
-- returns `t` itself.
local function init_he_uniform(t, fan_in, fan_out)
    local limit = sqrt(6 / fan_in)
    local count = #t
    for slot = 1, count do
        local unit = uniform() * 2 - 1 -- rescale the draw into [-1, 1)
        t[slot] = unit * limit
    end
    return t
end
-- Weight initializer: fills the sequence `t` in place with samples from
-- normal() scaled by sqrt(2 / fan_in).
-- fan_out is accepted for signature parity with the other initializers
-- and is not used.
-- returns `t` itself.
local function init_he_normal(t, fan_in, fan_out)
    local scale = sqrt(2 / fan_in)
    local count = #t
    for slot = 1, count do
        local sample = normal()
        t[slot] = sample * scale
    end
    return t
end
-- Shallow copy: returns a new table holding the same key/value pairs as `t`.
-- Nested tables are shared with the original, not duplicated.
local function copy(t)
    local clone = {}
    for key, value in pairs(t) do
        clone[key] = value
    end
    return clone
end
-- Builds a (possibly nested) table of zeros with the given shape and runs
-- an optional initializer over every innermost row.
-- t:    shape; either a number n (a flat row of n zeros) or a non-empty
--       array of dimensions such as {rows, cols}. A one-element array {n}
--       is treated exactly like the number n.
-- out:  optional table to fill. A flat shape overwrites out[1..n]; a nested
--       shape appends its rows to `out`.
-- init: optional function(row) called on each innermost row; its return
--       value is what gets stored or returned for that row.
-- returns `out` for nested shapes, or the (initialized) row for flat shapes.
local function allocate(t, out, init)
    out = out or {}
    assert(type(out) == "table", type(out))

    -- A shape of {n} used to recurse with an empty shape and crash on
    -- `for i = 1, nil`; collapse it to the scalar form up front.
    if type(t) == "table" and #t == 1 then
        t = t[1]
    end

    if type(t) == "number" then
        for i = 1, t do out[i] = 0 end
        if init ~= nil then
            return init(out)
        end
        return out
    end

    assert(type(t) == "table" and #t > 0,
        "shape must be a number or a non-empty array of dimensions")

    -- Every row shares the same remaining shape; build it once.
    local rest = {}
    for i = 2, #t do rest[i - 1] = t[i] end

    for _ = 1, t[1] do
        local row = allocate(rest, nil, init)
        assert(row ~= nil)
        out[#out + 1] = row
    end
    return out
end
-- Class tables for the module, all created through Base:extend().
-- NOTE(review): Base comes from require("Base"), which is not visible here;
-- extend() presumably returns a new subclass table whose instances are
-- constructed by calling the class and dispatching to :init() -- confirm.
-- Weights, Layer and Model derive from Base directly.
local Weights = Base:extend()
local Layer = Base:extend()
local Model = Base:extend()
-- Concrete layer types derive from Layer.
local Input = Layer:extend()
local Relu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
-- Stores the initializer that Weights:allocate later applies.
-- weight_init: function(t, fan_in, fan_out) that fills `t` and returns it.
function Weights:init(weight_init)
    self.weight_init = weight_init
end
-- Allocates storage of shape self.size directly inside this Weights object
-- and fills each innermost row through self.weight_init.
-- fan_in, fan_out: forwarded unchanged to the initializer.
-- returns whatever allocate() returns for this shape.
function Weights:allocate(fan_in, fan_out)
    local owner = self
    -- Adapter: allocate() calls its initializer with the row only, so the
    -- fan sizes are supplied from this closure.
    local function fill(row)
        return owner.weight_init(row, fan_in, fan_out)
    end
    return allocate(owner.size, owner, fill)
end
--[[
local w = Weights(init_he_uniform)
w.size = {16, 16}
w:allocate(16, 16)
print(w)
do return end
local w = zeros(16)
for i = 1, #w do w[i] = normal() * 1920 / 2560 end
print(w)
--]]
-- How many layers have been created per type name; used to number them.
local counter = {}

-- Sets up the bookkeeping shared by every layer.
-- name: type name (string); the instance is named "<name>[<k>]" where k is
--       the running count of layers created with that name.
-- size_in and size_out are deliberately left unset here; they are filled in
-- by subclasses or by make_shape().
function Layer:init(name)
    assert(type(name) == "string")
    local serial = (counter[name] or 0) + 1
    counter[name] = serial
    self.name = name .. "[" .. tostring(serial) .. "]"
    self.parents = {}
    self.children = {}
    self.weights = {}
end
-- Infers this layer's sizes when it is attached below `parent`.
-- An unset size_in is taken from parent.size_out; an unset size_out then
-- defaults to size_in. Sizes that are already set are kept.
function Layer:make_shape(parent)
    if self.size_in == nil then
        self.size_in = parent.size_out
    end
    if self.size_out == nil then
        self.size_out = self.size_in
    end
end
-- Connects `child` as a consumer of this layer's output.
-- The child's shape is resolved first, then both directions of the edge
-- are recorded (self.children and child.parents).
-- returns `child`, so calls can be chained: a:feed(b):feed(c).
function Layer:feed(child)
    assert(self.size_out ~= nil)
    child:make_shape(self)
    local children = self.children
    children[#children + 1] = child
    local parents = child.parents
    parents[#parents + 1] = self
    return child
end
-- Abstract forward pass; always raises. Concrete layers override this.
function Layer:forward()
    error("Unimplemented.")
end
-- Deterministic forward pass. The default simply delegates to forward()
-- with the same arguments and returns its results unchanged.
function Layer:forward_deterministic(...)
    return self:forward(...)
end
-- Creates a Weights object using initializer `init`, registers it in
-- self.weights, and returns it.
function Layer:_new_weights(init)
    local fresh = Weights(init)
    insert(self.weights, fresh)
    return fresh
end
-- Product of all numbers passed in. Any argument may be a number or an
-- array of numbers (arrays may nest); every element is multiplied in.
--   prod(2, 3, 4)   --> 24
--   prod({16, 16})  --> 256
--   prod({2, 3}, 4) --> 24   (trailing arguments are no longer dropped)
-- The empty product, prod() or prod({}), is 1.
-- A nil or other non-numeric element raises the usual arithmetic error.
local function prod(...)
    local count = select("#", ...)
    local args = {...}
    local ret = 1
    for i = 1, count do
        local v = args[i]
        if type(v) == "table" then
            for j = 1, #v do
                ret = ret * prod(v[j])
            end
        else
            ret = ret * v
        end
    end
    return ret
end
-- Total number of scalar parameters held by this layer: the sum, over all
-- registered weights, of the product of each weight's shape.
function Layer:getsize()
    local total = 0
    for _, weight in ipairs(self.weights) do
        local elements = prod(weight.size)
        total = total + elements
    end
    return total
end
-- (Re)initializes every weight of this layer.
-- Each Weights object is emptied of its array entries first, because
-- allocation appends rows for nested shapes; then it is allocated afresh
-- with this layer's size_in and size_out passed as the fan sizes.
function Layer:init_weights()
    for _, weight in ipairs(self.weights) do
        -- FIXME: HACK -- clear the old contents before reallocating.
        for slot in ipairs(weight) do
            weight[slot] = nil
        end
        weight:allocate(self.size_in, self.size_out)
    end
end
-- Runs this layer on the values arriving from its parents.
-- edges:         array of parent outputs; exactly one is required here.
--                Layers that take several parents must override this.
-- deterministic: when truthy, forward_deterministic() is used instead of
--                forward().
function Layer:_propagate(edges, deterministic)
    assert(#edges == 1, #edges)
    local input = edges[1]
    if not deterministic then
        return self:forward(input)
    end
    return self:forward_deterministic(input)
end
-- Gathers the already-computed outputs of this layer's parents and runs
-- the layer on them.
-- values:        table mapping layer -> output; parents without an entry
--                are skipped.
-- deterministic: forwarded to _propagate().
-- Asserts that at least one parent value was found.
-- returns the single output value of the layer.
function Layer:propagate(values, deterministic)
    local edges = {}
    for _, parent in ipairs(self.parents) do
        local X = values[parent]
        if X ~= nil then
            edges[#edges + 1] = X
        end
    end
    assert(#edges > 0, #edges)
    local Y = self:_propagate(edges, deterministic)
    return Y
end
-- Input layer constructor.
-- size: number of elements the layer accepts; it is used as both the input
--       and the output size, since the layer passes data through unchanged.
function Input:init(size)
    Layer.init(self, "Input")
    assert(type(size) == 'number')
    self.size_in, self.size_out = size, size
end
-- Identity forward pass: checks the input length against size_in and
-- returns the very same table.
function Input:forward(X)
    local length = #X
    assert(length == self.size_in)
    return X
end
-- Relu layer constructor; it has no parameters of its own and only
-- registers the layer under the type name "Relu".
function Relu:init()
    Layer.init(self, "Relu")
end
-- Rectified linear unit: Y[i] = X[i] when X[i] >= 0, otherwise 0.
-- The output table is created once and reused on every call, so the
-- returned table is overwritten by the next forward pass.
function Relu:forward(X)
    assert(#X == self.size_in)
    if not self.cache then
        self.cache = zeros(self.size_out)
    end
    local Y = self.cache
    for i = 1, #X do
        local x = X[i]
        if x >= 0 then
            Y[i] = x
        else
            Y[i] = 0
        end
    end
    assert(#Y == self.size_out)
    return Y
end
-- Fully connected layer constructor.
-- dim: number of output units.
-- Registers two weight sets: `coeffs` initialized with init_he_normal and
-- `biases` initialized with init_zeros. Their shapes are assigned later in
-- make_shape(), once the input size is known.
function Dense:init(dim)
    Layer.init(self, "Dense")
    assert(type(dim) == "number")
    self.dim = dim
    self.size_out = dim
    local coeffs = self:_new_weights(init_he_normal)
    local biases = self:_new_weights(init_zeros)
    self.coeffs = coeffs
    self.biases = biases
end
-- Fixes the weight shapes once the parent layer is known.
-- size_in is always taken from parent.size_out; coeffs become a
-- dim x size_in matrix (one row per output unit) and biases a vector of
-- length dim.
function Dense:make_shape(parent)
    local inputs = parent.size_out
    self.size_in = inputs
    self.coeffs.size = {self.dim, inputs}
    self.biases.size = self.dim
end
-- Affine transform: Y[i] = sum_j X[j] * coeffs[i][j] + biases[i].
-- The output table is created once and reused on every call, so the
-- returned table is overwritten by the next forward pass.
function Dense:forward(X)
    assert(#X == self.size_in)
    if not self.cache then
        self.cache = zeros(self.size_out)
    end
    local Y = self.cache
    for row = 1, #self.coeffs do
        local weights = self.coeffs[row]
        local acc = 0
        for col = 1, #X do
            acc = acc + X[col] * weights[col]
        end
        Y[row] = acc + self.biases[row]
    end
    assert(#Y == self.size_out)
    return Y
end
-- Softmax layer constructor; it has no parameters of its own and only
-- registers the layer under the type name "Softmax".
function Softmax:init()
    Layer.init(self, "Softmax")
end
-- Softmax: Y[i] = exp(X[i] - m) / sum_k exp(X[k] - m), with m = max(X).
-- Subtracting the maximum leaves the result mathematically unchanged and
-- guarantees that at least one exponent is exp(0) = 1, so the denominator
-- can neither overflow nor collapse to zero.
-- The output table is created once and reused on every call, so the
-- returned table is overwritten by the next forward pass.
function Softmax:forward(X)
    assert(#X == self.size_in)
    self.cache = self.cache or zeros(self.size_out)
    local Y = self.cache
    local n = #X

    -- Start below every possible input. Starting at 0 made alpha wrong for
    -- all-negative inputs, and very negative logits then underflowed every
    -- exp() to 0, producing 0/0 = NaN.
    local alpha = -math.huge
    for i = 1, n do alpha = max(alpha, X[i]) end

    -- Keep the numerators in the output buffer itself instead of allocating
    -- a scratch table on every call.
    local den = 0
    for i = 1, n do
        local e = exp(X[i] - alpha)
        Y[i] = e
        den = den + e
    end
    for i = 1, n do Y[i] = Y[i] / den end

    assert(#Y == self.size_out)
    return Y
end
-- Breadth-first walk of a node graph.
-- field:   name of the adjacency field to follow ('children' or 'parents');
--          every visited node must hold an array under that key.
-- node_in: starting node; when nil nothing is visited.
-- nodes:   optional array to append the visit order to.
-- returns `nodes` with each visited node appended in level order.
-- No visited set is kept: a node reachable along several paths is appended
-- once per path, and a cyclic graph never terminates.
local function levelorder(field, node_in, nodes)
    nodes = nodes or {}
    -- FIFO queue with a moving head index; popping with remove(q, 1)
    -- shifted the whole queue on every step.
    local queue = {node_in}
    local head = 1
    local tail = (node_in ~= nil) and 1 or 0
    while head <= tail do
        local node = queue[head]
        queue[head] = nil
        head = head + 1
        insert(nodes, node)
        for _, child in ipairs(node[field]) do
            tail = tail + 1
            queue[tail] = child
        end
    end
    return nodes
end
-- Collects the nodes that lie between node_in and node_out: those reachable
-- from node_in through 'children' AND from node_out through 'parents'.
-- nodes: optional array to append to.
-- returns `nodes`, in the order the downward walk from node_in visits them.
-- NOTE(review): levelorder can list a node more than once (one entry per
-- path), so such a node is appended here more than once as well.
local function traverse(node_in, node_out, nodes)
    nodes = nodes or {}
    local down = levelorder('children', node_in, {})
    local up = levelorder('parents', node_out, {})

    -- Bit 1 marks "reaches the output", bit 2 marks "reached from the input".
    local seen = {}
    for _, node in ipairs(up) do
        local flags = seen[node] or 0
        seen[node] = bor(flags, 1)
    end
    for _, node in ipairs(down) do
        local flags = bor(seen[node] or 0, 2)
        seen[node] = flags
        if flags == 3 then
            insert(nodes, node)
        end
    end
    return nodes
end
-- Multi-input / multi-output version of traverse().
-- A temporary source node (its children are all of nodes_in) and a temporary
-- sink node (its parents are all of nodes_out) are built so the single-node
-- traverse() can be reused. The source has no 'parents' walk reaching it
-- and the sink no 'children' walk, so neither ends up in the result.
local function traverse_all(nodes_in, nodes_out, nodes)
    local source = {children = {}}
    local sink = {parents = {}}
    for _, node in ipairs(nodes_in) do
        insert(source.children, node)
    end
    for _, node in ipairs(nodes_out) do
        insert(sink.parents, node)
    end
    if not nodes then
        nodes = {}
    end
    return traverse(source, sink, nodes)
end
-- Model constructor.
-- nodes_in:  non-empty array of layers that receive the model input.
-- nodes_out: non-empty array of layers whose outputs the model returns.
-- self.nodes is set to every layer lying on a path from an input to an
-- output, as computed by traverse_all().
function Model:init(nodes_in, nodes_out)
    assert(#nodes_in > 0, #nodes_in)
    assert(#nodes_out > 0, #nodes_out)
    self.nodes_in = nodes_in
    self.nodes_out = nodes_out
    local used = traverse_all(self.nodes_in, self.nodes_out)
    self.nodes = used
end
-- (Re)initializes the weights of every layer in the model and recomputes
-- self.n_param, the sum of the layers' getsize() values.
function Model:reset()
    self.n_param = 0
    for _, layer in ipairs(self.nodes) do
        layer:init_weights()
        local added = layer:getsize()
        self.n_param = self.n_param + added
    end
end
-- Runs the whole graph on input X.
-- Layers are evaluated in the order stored in self.nodes. Input layers are
-- fed X directly; every other layer is fed the values already computed for
-- its parents.
-- returns a table mapping each output layer to its value.
function Model:forward(X)
    local values = {}
    local outputs = {}
    for _, node in ipairs(self.nodes) do
        local Y
        if contains(self.nodes_in, node) then
            Y = node:_propagate({X})
        else
            Y = node:propagate(values)
        end
        values[node] = Y
        if contains(self.nodes_out, node) then
            outputs[node] = values[node]
        end
    end
    return outputs
end
-- Public interface of the module.
local exports = {
    -- random sampling and array helpers
    uniform = uniform,
    normal = normal,
    zeros = zeros,

    -- weight initializers
    init_zeros = init_zeros,
    init_he_uniform = init_he_uniform,
    init_he_normal = init_he_normal,

    -- core classes
    Weights = Weights,
    Layer = Layer,
    Model = Model,

    -- layer types
    Input = Input,
    Relu = Relu,
    Dense = Dense,
    Softmax = Softmax,
}

return exports