organization etc.

This commit is contained in:
Connor Olding 2017-09-07 19:06:25 +00:00
parent 16b38a7151
commit 2d4ce31c7e

30
nn.lua
View File

@ -14,9 +14,12 @@ local sin = math.sin
local sqrt = math.sqrt
local tostring = tostring
local uniform = math.random
local unpack = table.unpack or unpack
local Base = require("Base")
-- general utilities
local function copy(t) -- shallow copy
local new_t = {}
for k, v in pairs(t) do new_t[k] = v end
@ -33,6 +36,8 @@ local function contains(t, a)
return indexof(t, a) ~= nil
end
-- math utilities
local function prod(x, ...)
if type(x) == "table" then
return prod(unpack(x))
@ -54,6 +59,16 @@ local function zeros(n, out)
return out
end
local function allocate(t, out, init)
out = out or {}
local size = t
if init ~= nil then
return init(zeros(size, out))
else
return zeros(size, out)
end
end
local function init_zeros(t, fan_in, fan_out)
for i = 1, #t do t[i] = 0 end
return t
@ -71,15 +86,7 @@ local function init_he_normal(t, fan_in, fan_out)
return t
end
local function allocate(t, out, init)
out = out or {}
local size = t
if init ~= nil then
return init(zeros(size, out))
else
return zeros(size, out)
end
end
-- nodal
local function traverse(node_in, node_out, nodes, dummy_mode)
-- i have no idea if this is any algorithm in particular.
@ -127,6 +134,8 @@ local function traverse_all(nodes_in, nodes_out, nodes)
return traverse(all_in, all_out, nodes or {}, true)
end
-- classes
local Weights = Base:extend()
local Layer = Base:extend()
local Model = Base:extend()
@ -457,11 +466,12 @@ return {
prod = prod,
normal = normal,
zeros = zeros,
allocate = allocate,
init_zeros = init_zeros,
init_he_uniform = init_he_uniform,
init_he_normal = init_he_normal,
allocate = allocate,
traverse = traverse,
traverse_all = traverse_all,
Weights = Weights,
Layer = Layer,