remove remnants of backwards pass

This commit is contained in:
Connor Olding 2018-06-09 15:23:14 +02:00
parent f03e80b1b6
commit ae331ce60b

74
nn.lua
View file

@ -367,6 +367,7 @@ end
-- Record `bs` on the layer and build a fresh buffer for it via
-- cache(bs, self.shape_out). The forward methods call this whenever the `bs`
-- reported by checkshape differs from the stored self.bs, and then write
-- their output into self.cache.
function Layer:reset_cache(bs)
    local out_shape = self.shape_out
    self.bs = bs
    self.cache = cache(bs, out_shape)
end
function Layer:_propagate(edges, deterministic)
@ -417,12 +418,6 @@ function Merge:make_shape(parent)
self.shape_out = {self.size}
end
-- Re-create the buffers of this Merge node for a new `bs`.
-- Merge needs nothing beyond the generic bookkeeping (self.bs and an
-- output-shaped self.cache), so defer to Layer.reset_cache rather than
-- repeating it here.
function Merge:reset_cache(bs)
    Layer.reset_cache(self, bs)
end
function Merge:_propagate(edges, deterministic)
assert(#edges == self.shape_in)
local bs = edges[1].shape[1]
@ -445,13 +440,6 @@ function Relu:init()
Layer.init(self, "Relu")
end
-- Re-create the buffers of this Relu layer for a new `bs`.
-- The generic part (self.bs and the output-shaped self.cache) is handled by
-- Layer.reset_cache so that bookkeeping lives in one place.
function Relu:reset_cache(bs)
    Layer.reset_cache(self, bs)
    -- Input-shaped buffer. NOTE(review): nothing visible here reads dcache;
    -- it looks like a leftover of the backward pass -- confirm before use.
    self.dcache = cache(bs, self.shape_in)
end
function Relu:forward(X)
local bs = checkshape(X, self.shape_in)
if bs ~= self.bs then self:reset_cache(bs) end
@ -467,27 +455,14 @@ function Gelu:init()
Layer.init(self, "Gelu")
end
-- Re-create the buffers of this Gelu layer for a new `bs`.
-- The generic part (self.bs and the output-shaped self.cache) is handled by
-- Layer.reset_cache so that bookkeeping lives in one place.
function Gelu:reset_cache(bs)
    Layer.reset_cache(self, bs)
    -- Output-shaped scratch buffers: Gelu:forward stores the scaled input in
    -- cache_a and the sigmoid of it in cache_sig.
    self.cache_a = cache(bs, self.shape_out)
    self.cache_sig = cache(bs, self.shape_out)
    -- Input-shaped buffer. NOTE(review): nothing visible here reads dcache;
    -- it looks like a leftover of the backward pass -- confirm before use.
    self.dcache = cache(bs, self.shape_in)
end
function Gelu:forward(X)
local bs = checkshape(X, self.shape_in)
if bs ~= self.bs then self:reset_cache(bs) end
local Y = self.cache
local a = self.cache_a
local sig = self.cache_sig
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
for i = 1, #X do
a[i] = 1.704 * X[i]
sig[i] = 1 / (1 + exp(-a[i]))
Y[i] = X[i] * sig[i]
Y[i] = X[i] / (1 + exp(-1.704 * X[i]))
end
checkshape(Y, self.shape_out)
@ -498,13 +473,6 @@ function Cos:init()
Layer.init(self, "Cos")
end
-- Re-create the buffers of this Cos layer for a new `bs`.
-- The generic part (self.bs and the output-shaped self.cache) is handled by
-- Layer.reset_cache so that bookkeeping lives in one place.
function Cos:reset_cache(bs)
    Layer.reset_cache(self, bs)
    -- Input-shaped buffer. NOTE(review): nothing visible here reads dcache;
    -- it looks like a leftover of the backward pass -- confirm before use.
    self.dcache = cache(bs, self.shape_in)
end
function Cos:forward(X)
local bs = checkshape(X, self.shape_in)
if bs ~= self.bs then self:reset_cache(bs) end
@ -520,13 +488,6 @@ function Tanh:init()
Layer.init(self, "Tanh")
end
-- Re-create the buffers of this Tanh layer for a new `bs`.
-- The generic part (self.bs and the output-shaped self.cache) is handled by
-- Layer.reset_cache so that bookkeeping lives in one place.
function Tanh:reset_cache(bs)
    Layer.reset_cache(self, bs)
    -- Input-shaped buffer. NOTE(review): nothing visible here reads dcache;
    -- it looks like a leftover of the backward pass -- confirm before use.
    self.dcache = cache(bs, self.shape_in)
end
function Tanh:forward(X)
local bs = checkshape(X, self.shape_in)
if bs ~= self.bs then self:reset_cache(bs) end
@ -553,20 +514,11 @@ function Dense:make_shape(parent)
self.biases.shape = {1, self.dim}
end
-- Re-create the buffers of this Dense layer for a new `bs`.
-- The generic part (self.bs and the output-shaped self.cache, which
-- Dense:forward uses as the destination of its dot product) is handled by
-- Layer.reset_cache so that bookkeeping lives in one place.
function Dense:reset_cache(bs)
    Layer.reset_cache(self, bs)
    -- Input-shaped buffers. NOTE(review): nothing visible here reads cache_x
    -- or dcache; presumably they held the saved input and the input gradient
    -- for the backward pass -- confirm before relying on them.
    self.cache_x = cache(bs, self.shape_in)
    self.dcache = cache(bs, self.shape_in)
end
function Dense:forward(X)
local bs = checkshape(X, self.shape_in)
if self.bs ~= bs then self:reset_cache(bs) end
local Y = self.cache
--dot_1aab(X, self.coeffs, Y)
dot(X, self.coeffs, 2, 1, Y)
for i = 1, self.dim do
@ -581,12 +533,6 @@ function Softmax:init()
Layer.init(self, "Softmax")
end
-- Re-create the buffers of this Softmax layer for a new `bs`.
-- Softmax needs nothing beyond the generic bookkeeping (self.bs and an
-- output-shaped self.cache), so defer to Layer.reset_cache rather than
-- repeating it here.
function Softmax:reset_cache(bs)
    Layer.reset_cache(self, bs)
end
function Softmax:forward(X)
local bs = checkshape(X, self.shape_in)
if self.bs ~= bs then self:reset_cache(bs) end
@ -624,13 +570,6 @@ function Embed:make_shape(parent)
self.shape_out = {parent.shape_out[1] * self.dim}
end
-- Re-create the buffers of this Embed layer for a new `bs`.
-- The generic part (self.bs and the output-shaped self.cache) is handled by
-- Layer.reset_cache so that bookkeeping lives in one place.
function Embed:reset_cache(bs)
    Layer.reset_cache(self, bs)
    -- Input-shaped buffer. NOTE(review): nothing visible here reads cache_x;
    -- presumably it held the saved input for the backward pass -- confirm.
    self.cache_x = cache(bs, self.shape_in)
end
function Embed:forward(X)
local bs = checkshape(X, self.shape_in)
if self.bs ~= bs then self:reset_cache(bs) end
@ -656,11 +595,6 @@ function LayerNorm:init(eps)
self.eps = eps
end
-- Re-create the buffers of this LayerNorm layer for a new `bs`.
-- LayerNorm needs nothing beyond the generic bookkeeping (self.bs and an
-- output-shaped self.cache), so defer to Layer.reset_cache rather than
-- repeating it here.
function LayerNorm:reset_cache(bs)
    Layer.reset_cache(self, bs)
end
function LayerNorm:forward(X)
local bs = checkshape(X, self.shape_in)
if self.bs ~= bs then self:reset_cache(bs) end
@ -727,10 +661,6 @@ function Model:forward(inputs)
return outputs
end
-- Placeholder: clearing gradients is not implemented. Always raises.
-- The message names the method, and level 2 makes the reported position the
-- caller's line, so the failure points at the code that needs changing
-- rather than at this stub.
function Model:cleargrad()
    error("TODO: Model:cleargrad is not implemented", 2)
end
function Model:print()
print("digraph G {")
for _, parent in ipairs(self.nodes) do