looking forwards
This commit is contained in:
parent
db603753f4
commit
3e3b4d9207
1 changed files with 19 additions and 0 deletions
19
nn.lua
19
nn.lua
|
@ -59,6 +59,12 @@ local function zeros(n, out)
|
|||
return out
|
||||
end
|
||||
|
||||
local function arange(n, out)
|
||||
out = out or {}
|
||||
for i = 1, n do out[i] = i - 1 end
|
||||
return out
|
||||
end
|
||||
|
||||
local function allocate(t, out, init)
|
||||
out = out or {}
|
||||
local size = t
|
||||
|
@ -353,6 +359,10 @@ function Softmax:forward(X)
|
|||
return Y
|
||||
end
|
||||
|
||||
--function Softmax:backward(dY)
|
||||
--return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache
|
||||
--end
|
||||
|
||||
function Model:init(nodes_in, nodes_out)
|
||||
assert(#nodes_in > 0, #nodes_in)
|
||||
assert(#nodes_out > 0, #nodes_out)
|
||||
|
@ -393,6 +403,14 @@ function Model:forward(inputs)
|
|||
return outputs
|
||||
end
|
||||
|
||||
function Model:cleargrad()
|
||||
error("TODO") -- TODO
|
||||
end
|
||||
|
||||
function Model:print()
|
||||
error("TODO") -- TODO
|
||||
end
|
||||
|
||||
function Model:collect()
|
||||
-- return a flat array of all the weights in the graph.
|
||||
-- if Lua had slices, we wouldn't need this. future library idea?
|
||||
|
@ -468,6 +486,7 @@ return {
|
|||
prod = prod,
|
||||
normal = normal,
|
||||
zeros = zeros,
|
||||
arange = arange,
|
||||
allocate = allocate,
|
||||
init_zeros = init_zeros,
|
||||
init_he_uniform = init_he_uniform,
|
||||
|
|
Loading…
Reference in a new issue