diff --git a/nn.lua b/nn.lua index 3b5f94f..652f4c1 100644 --- a/nn.lua +++ b/nn.lua @@ -59,6 +59,12 @@ local function zeros(n, out) return out end +local function arange(n, out) + out = out or {} + for i = 1, n do out[i] = i - 1 end + return out +end + local function allocate(t, out, init) out = out or {} local size = t @@ -353,6 +359,10 @@ function Softmax:forward(X) return Y end +--function Softmax:backward(dY) + --return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache +--end + function Model:init(nodes_in, nodes_out) assert(#nodes_in > 0, #nodes_in) assert(#nodes_out > 0, #nodes_out) @@ -393,6 +403,14 @@ function Model:forward(inputs) return outputs end +function Model:cleargrad() + error("TODO") -- TODO +end + +function Model:print() + error("TODO") -- TODO +end + function Model:collect() -- return a flat array of all the weights in the graph. -- if Lua had slices, we wouldn't need this. future library idea? @@ -468,6 +486,7 @@ return { prod = prod, normal = normal, zeros = zeros, + arange = arange, allocate = allocate, init_zeros = init_zeros, init_he_uniform = init_he_uniform,