From ae331ce60bfa6ecc3adbf72a52d788317000b813 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sat, 9 Jun 2018 15:23:14 +0200 Subject: [PATCH] remove remnants of backwards pass --- nn.lua | 74 ++-------------------------------------------------------- 1 file changed, 2 insertions(+), 72 deletions(-) diff --git a/nn.lua b/nn.lua index bb3e828..edd9391 100644 --- a/nn.lua +++ b/nn.lua @@ -367,6 +367,7 @@ end function Layer:reset_cache(bs) self.bs = bs + self.cache = cache(bs, self.shape_out) end function Layer:_propagate(edges, deterministic) @@ -417,12 +418,6 @@ function Merge:make_shape(parent) self.shape_out = {self.size} end -function Merge:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) -end - function Merge:_propagate(edges, deterministic) assert(#edges == self.shape_in) local bs = edges[1].shape[1] @@ -445,13 +440,6 @@ function Relu:init() Layer.init(self, "Relu") end -function Relu:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) - self.dcache = cache(bs, self.shape_in) -end - function Relu:forward(X) local bs = checkshape(X, self.shape_in) if bs ~= self.bs then self:reset_cache(bs) end @@ -467,27 +455,14 @@ function Gelu:init() Layer.init(self, "Gelu") end -function Gelu:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) - self.cache_a = cache(bs, self.shape_out) - self.cache_sig = cache(bs, self.shape_out) - self.dcache = cache(bs, self.shape_in) -end - function Gelu:forward(X) local bs = checkshape(X, self.shape_in) if bs ~= self.bs then self:reset_cache(bs) end local Y = self.cache - local a = self.cache_a - local sig = self.cache_sig -- NOTE: approximate form of GELU exploiting similarities to sigmoid curve. for i = 1, #X do - a[i] = 1.704 * X[i] - sig[i] = 1 / (1 + exp(-a[i])) - Y[i] = X[i] * sig[i] + Y[i] = X[i] / (1 + exp(-1.704 * X[i])) end checkshape(Y, self.shape_out) @@ -498,13 +473,6 @@ function Cos:init() Layer.init(self, "Cos") end -function Cos:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) - self.dcache = cache(bs, self.shape_in) -end - function Cos:forward(X) local bs = checkshape(X, self.shape_in) if bs ~= self.bs then self:reset_cache(bs) end @@ -520,13 +488,6 @@ function Tanh:init() Layer.init(self, "Tanh") end -function Tanh:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) - self.dcache = cache(bs, self.shape_in) -end - function Tanh:forward(X) local bs = checkshape(X, self.shape_in) if bs ~= self.bs then self:reset_cache(bs) end @@ -553,20 +514,11 @@ function Dense:make_shape(parent) self.biases.shape = {1, self.dim} end -function Dense:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) - self.cache_x = cache(bs, self.shape_in) - self.dcache = cache(bs, self.shape_in) -end - function Dense:forward(X) local bs = checkshape(X, self.shape_in) if self.bs ~= bs then self:reset_cache(bs) end local Y = self.cache - --dot_1aab(X, self.coeffs, Y) dot(X, self.coeffs, 2, 1, Y) for i = 1, self.dim do @@ -581,12 +533,6 @@ function Softmax:init() Layer.init(self, "Softmax") end -function Softmax:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) -end - function Softmax:forward(X) local bs = checkshape(X, self.shape_in) if self.bs ~= bs then self:reset_cache(bs) end @@ -624,13 +570,6 @@ function Embed:make_shape(parent) self.shape_out = {parent.shape_out[1] * self.dim} end -function Embed:reset_cache(bs) - self.bs = bs - - self.cache = cache(bs, self.shape_out) - self.cache_x = cache(bs, self.shape_in) -end - function Embed:forward(X) local bs = checkshape(X, self.shape_in) if self.bs ~= bs then self:reset_cache(bs) end @@ -656,11 +595,6 @@ function LayerNorm:init(eps) self.eps = eps end -function LayerNorm:reset_cache(bs) - self.bs = bs - self.cache = cache(bs, self.shape_out) -end - function LayerNorm:forward(X) local bs = checkshape(X, self.shape_in) if self.bs ~= bs then self:reset_cache(bs) end @@ -727,10 +661,6 @@ function Model:forward(inputs) return outputs end -function Model:cleargrad() - error("TODO") -- TODO -end - function Model:print() print("digraph G {") for _, parent in ipairs(self.nodes) do