From 2d4ce31c7efcea984bc22fadabe3dcf496ffaf5e Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 7 Sep 2017 19:06:25 +0000 Subject: [PATCH] organization etc. --- nn.lua | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/nn.lua b/nn.lua index 76d18b1..48e123a 100644 --- a/nn.lua +++ b/nn.lua @@ -14,9 +14,12 @@ local sin = math.sin local sqrt = math.sqrt local tostring = tostring local uniform = math.random +local unpack = table.unpack or unpack local Base = require("Base") +-- general utilities + local function copy(t) -- shallow copy local new_t = {} for k, v in pairs(t) do new_t[k] = v end @@ -33,6 +36,8 @@ local function contains(t, a) return indexof(t, a) ~= nil end +-- math utilities + local function prod(x, ...) if type(x) == "table" then return prod(unpack(x)) @@ -54,6 +59,16 @@ local function zeros(n, out) return out end +local function allocate(t, out, init) + out = out or {} + local size = t + if init ~= nil then + return init(zeros(size, out)) + else + return zeros(size, out) + end +end + local function init_zeros(t, fan_in, fan_out) for i = 1, #t do t[i] = 0 end return t @@ -71,15 +86,7 @@ local function init_he_normal(t, fan_in, fan_out) return t end -local function allocate(t, out, init) - out = out or {} - local size = t - if init ~= nil then - return init(zeros(size, out)) - else - return zeros(size, out) - end -end +-- nodal local function traverse(node_in, node_out, nodes, dummy_mode) -- i have no idea if this is any algorithm in particular. @@ -127,6 +134,8 @@ local function traverse_all(nodes_in, nodes_out, nodes) return traverse(all_in, all_out, nodes or {}, true) end +-- classes + local Weights = Base:extend() local Layer = Base:extend() local Model = Base:extend() @@ -457,11 +466,12 @@ return { prod = prod, normal = normal, zeros = zeros, + allocate = allocate, init_zeros = init_zeros, init_he_uniform = init_he_uniform, init_he_normal = init_he_normal, - allocate = allocate, traverse = traverse, + traverse_all = traverse_all, Weights = Weights, Layer = Layer,