organization etc.
This commit is contained in:
parent
16b38a7151
commit
2d4ce31c7e
1 changed files with 20 additions and 10 deletions
30
nn.lua
30
nn.lua
|
@ -14,9 +14,12 @@ local sin = math.sin
|
||||||
local sqrt = math.sqrt
|
local sqrt = math.sqrt
|
||||||
local tostring = tostring
|
local tostring = tostring
|
||||||
local uniform = math.random
|
local uniform = math.random
|
||||||
|
local unpack = table.unpack or unpack
|
||||||
|
|
||||||
local Base = require("Base")
|
local Base = require("Base")
|
||||||
|
|
||||||
|
-- general utilities
|
||||||
|
|
||||||
local function copy(t) -- shallow copy
|
local function copy(t) -- shallow copy
|
||||||
local new_t = {}
|
local new_t = {}
|
||||||
for k, v in pairs(t) do new_t[k] = v end
|
for k, v in pairs(t) do new_t[k] = v end
|
||||||
|
@ -33,6 +36,8 @@ local function contains(t, a)
|
||||||
return indexof(t, a) ~= nil
|
return indexof(t, a) ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- math utilities
|
||||||
|
|
||||||
local function prod(x, ...)
|
local function prod(x, ...)
|
||||||
if type(x) == "table" then
|
if type(x) == "table" then
|
||||||
return prod(unpack(x))
|
return prod(unpack(x))
|
||||||
|
@ -54,6 +59,16 @@ local function zeros(n, out)
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function allocate(t, out, init)
|
||||||
|
out = out or {}
|
||||||
|
local size = t
|
||||||
|
if init ~= nil then
|
||||||
|
return init(zeros(size, out))
|
||||||
|
else
|
||||||
|
return zeros(size, out)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
local function init_zeros(t, fan_in, fan_out)
|
local function init_zeros(t, fan_in, fan_out)
|
||||||
for i = 1, #t do t[i] = 0 end
|
for i = 1, #t do t[i] = 0 end
|
||||||
return t
|
return t
|
||||||
|
@ -71,15 +86,7 @@ local function init_he_normal(t, fan_in, fan_out)
|
||||||
return t
|
return t
|
||||||
end
|
end
|
||||||
|
|
||||||
local function allocate(t, out, init)
|
-- nodal
|
||||||
out = out or {}
|
|
||||||
local size = t
|
|
||||||
if init ~= nil then
|
|
||||||
return init(zeros(size, out))
|
|
||||||
else
|
|
||||||
return zeros(size, out)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||||
-- i have no idea if this is any algorithm in particular.
|
-- i have no idea if this is any algorithm in particular.
|
||||||
|
@ -127,6 +134,8 @@ local function traverse_all(nodes_in, nodes_out, nodes)
|
||||||
return traverse(all_in, all_out, nodes or {}, true)
|
return traverse(all_in, all_out, nodes or {}, true)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- classes
|
||||||
|
|
||||||
local Weights = Base:extend()
|
local Weights = Base:extend()
|
||||||
local Layer = Base:extend()
|
local Layer = Base:extend()
|
||||||
local Model = Base:extend()
|
local Model = Base:extend()
|
||||||
|
@ -457,11 +466,12 @@ return {
|
||||||
prod = prod,
|
prod = prod,
|
||||||
normal = normal,
|
normal = normal,
|
||||||
zeros = zeros,
|
zeros = zeros,
|
||||||
|
allocate = allocate,
|
||||||
init_zeros = init_zeros,
|
init_zeros = init_zeros,
|
||||||
init_he_uniform = init_he_uniform,
|
init_he_uniform = init_he_uniform,
|
||||||
init_he_normal = init_he_normal,
|
init_he_normal = init_he_normal,
|
||||||
allocate = allocate,
|
|
||||||
traverse = traverse,
|
traverse = traverse,
|
||||||
|
traverse_all = traverse_all,
|
||||||
|
|
||||||
Weights = Weights,
|
Weights = Weights,
|
||||||
Layer = Layer,
|
Layer = Layer,
|
||||||
|
|
Loading…
Add table
Reference in a new issue