organization etc.
This commit is contained in:
parent
16b38a7151
commit
2d4ce31c7e
1 changed files with 20 additions and 10 deletions
30
nn.lua
30
nn.lua
|
@ -14,9 +14,12 @@ local sin = math.sin
|
|||
local sqrt = math.sqrt
|
||||
local tostring = tostring
|
||||
local uniform = math.random
|
||||
local unpack = table.unpack or unpack
|
||||
|
||||
local Base = require("Base")
|
||||
|
||||
-- general utilities
|
||||
|
||||
local function copy(t) -- shallow copy
|
||||
local new_t = {}
|
||||
for k, v in pairs(t) do new_t[k] = v end
|
||||
|
@ -33,6 +36,8 @@ local function contains(t, a)
|
|||
return indexof(t, a) ~= nil
|
||||
end
|
||||
|
||||
-- math utilities
|
||||
|
||||
local function prod(x, ...)
|
||||
if type(x) == "table" then
|
||||
return prod(unpack(x))
|
||||
|
@ -54,6 +59,16 @@ local function zeros(n, out)
|
|||
return out
|
||||
end
|
||||
|
||||
local function allocate(t, out, init)
|
||||
out = out or {}
|
||||
local size = t
|
||||
if init ~= nil then
|
||||
return init(zeros(size, out))
|
||||
else
|
||||
return zeros(size, out)
|
||||
end
|
||||
end
|
||||
|
||||
local function init_zeros(t, fan_in, fan_out)
|
||||
for i = 1, #t do t[i] = 0 end
|
||||
return t
|
||||
|
@ -71,15 +86,7 @@ local function init_he_normal(t, fan_in, fan_out)
|
|||
return t
|
||||
end
|
||||
|
||||
local function allocate(t, out, init)
|
||||
out = out or {}
|
||||
local size = t
|
||||
if init ~= nil then
|
||||
return init(zeros(size, out))
|
||||
else
|
||||
return zeros(size, out)
|
||||
end
|
||||
end
|
||||
-- nodal
|
||||
|
||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||
-- i have no idea if this is any algorithm in particular.
|
||||
|
@ -127,6 +134,8 @@ local function traverse_all(nodes_in, nodes_out, nodes)
|
|||
return traverse(all_in, all_out, nodes or {}, true)
|
||||
end
|
||||
|
||||
-- classes
|
||||
|
||||
local Weights = Base:extend()
|
||||
local Layer = Base:extend()
|
||||
local Model = Base:extend()
|
||||
|
@ -457,11 +466,12 @@ return {
|
|||
prod = prod,
|
||||
normal = normal,
|
||||
zeros = zeros,
|
||||
allocate = allocate,
|
||||
init_zeros = init_zeros,
|
||||
init_he_uniform = init_he_uniform,
|
||||
init_he_normal = init_he_normal,
|
||||
allocate = allocate,
|
||||
traverse = traverse,
|
||||
traverse_all = traverse_all,
|
||||
|
||||
Weights = Weights,
|
||||
Layer = Layer,
|
||||
|
|
Loading…
Reference in a new issue