organization etc.

This commit is contained in:
Connor Olding 2017-09-07 19:06:25 +00:00
parent 16b38a7151
commit 2d4ce31c7e

30
nn.lua
View File

@ -14,9 +14,12 @@ local sin = math.sin
local sqrt = math.sqrt local sqrt = math.sqrt
local tostring = tostring local tostring = tostring
local uniform = math.random local uniform = math.random
local unpack = table.unpack or unpack
local Base = require("Base") local Base = require("Base")
-- general utilities
local function copy(t) -- shallow copy local function copy(t) -- shallow copy
local new_t = {} local new_t = {}
for k, v in pairs(t) do new_t[k] = v end for k, v in pairs(t) do new_t[k] = v end
@ -33,6 +36,8 @@ local function contains(t, a)
return indexof(t, a) ~= nil return indexof(t, a) ~= nil
end end
-- math utilities
local function prod(x, ...) local function prod(x, ...)
if type(x) == "table" then if type(x) == "table" then
return prod(unpack(x)) return prod(unpack(x))
@ -54,6 +59,16 @@ local function zeros(n, out)
return out return out
end end
local function allocate(t, out, init)
out = out or {}
local size = t
if init ~= nil then
return init(zeros(size, out))
else
return zeros(size, out)
end
end
local function init_zeros(t, fan_in, fan_out) local function init_zeros(t, fan_in, fan_out)
for i = 1, #t do t[i] = 0 end for i = 1, #t do t[i] = 0 end
return t return t
@ -71,15 +86,7 @@ local function init_he_normal(t, fan_in, fan_out)
return t return t
end end
local function allocate(t, out, init) -- nodal
out = out or {}
local size = t
if init ~= nil then
return init(zeros(size, out))
else
return zeros(size, out)
end
end
local function traverse(node_in, node_out, nodes, dummy_mode) local function traverse(node_in, node_out, nodes, dummy_mode)
-- i have no idea if this is any algorithm in particular. -- i have no idea if this is any algorithm in particular.
@ -127,6 +134,8 @@ local function traverse_all(nodes_in, nodes_out, nodes)
return traverse(all_in, all_out, nodes or {}, true) return traverse(all_in, all_out, nodes or {}, true)
end end
-- classes
local Weights = Base:extend() local Weights = Base:extend()
local Layer = Base:extend() local Layer = Base:extend()
local Model = Base:extend() local Model = Base:extend()
@ -457,11 +466,12 @@ return {
prod = prod, prod = prod,
normal = normal, normal = normal,
zeros = zeros, zeros = zeros,
allocate = allocate,
init_zeros = init_zeros, init_zeros = init_zeros,
init_he_uniform = init_he_uniform, init_he_uniform = init_he_uniform,
init_he_normal = init_he_normal, init_he_normal = init_he_normal,
allocate = allocate,
traverse = traverse, traverse = traverse,
traverse_all = traverse_all,
Weights = Weights, Weights = Weights,
Layer = Layer, Layer = Layer,