remove remnants of backwards pass
This commit is contained in:
parent
f03e80b1b6
commit
ae331ce60b
1 changed file with 2 additions and 72 deletions
74
nn.lua
74
nn.lua
|
@@ -367,6 +367,7 @@ end
|
||||||
|
|
||||||
function Layer:reset_cache(bs)
|
function Layer:reset_cache(bs)
|
||||||
self.bs = bs
|
self.bs = bs
|
||||||
|
self.cache = cache(bs, self.shape_out)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Layer:_propagate(edges, deterministic)
|
function Layer:_propagate(edges, deterministic)
|
||||||
|
@@ -417,12 +418,6 @@ function Merge:make_shape(parent)
|
||||||
self.shape_out = {self.size}
|
self.shape_out = {self.size}
|
||||||
end
|
end
|
||||||
|
|
||||||
function Merge:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Merge:_propagate(edges, deterministic)
|
function Merge:_propagate(edges, deterministic)
|
||||||
assert(#edges == self.shape_in)
|
assert(#edges == self.shape_in)
|
||||||
local bs = edges[1].shape[1]
|
local bs = edges[1].shape[1]
|
||||||
|
@@ -445,13 +440,6 @@ function Relu:init()
|
||||||
Layer.init(self, "Relu")
|
Layer.init(self, "Relu")
|
||||||
end
|
end
|
||||||
|
|
||||||
function Relu:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
self.dcache = cache(bs, self.shape_in)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Relu:forward(X)
|
function Relu:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if bs ~= self.bs then self:reset_cache(bs) end
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
|
@@ -467,27 +455,14 @@ function Gelu:init()
|
||||||
Layer.init(self, "Gelu")
|
Layer.init(self, "Gelu")
|
||||||
end
|
end
|
||||||
|
|
||||||
function Gelu:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
self.cache_a = cache(bs, self.shape_out)
|
|
||||||
self.cache_sig = cache(bs, self.shape_out)
|
|
||||||
self.dcache = cache(bs, self.shape_in)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Gelu:forward(X)
|
function Gelu:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if bs ~= self.bs then self:reset_cache(bs) end
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
local a = self.cache_a
|
|
||||||
local sig = self.cache_sig
|
|
||||||
|
|
||||||
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
|
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
|
||||||
for i = 1, #X do
|
for i = 1, #X do
|
||||||
a[i] = 1.704 * X[i]
|
Y[i] = X[i] / (1 + exp(-1.704 * X[i]))
|
||||||
sig[i] = 1 / (1 + exp(-a[i]))
|
|
||||||
Y[i] = X[i] * sig[i]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
checkshape(Y, self.shape_out)
|
checkshape(Y, self.shape_out)
|
||||||
|
@@ -498,13 +473,6 @@ function Cos:init()
|
||||||
Layer.init(self, "Cos")
|
Layer.init(self, "Cos")
|
||||||
end
|
end
|
||||||
|
|
||||||
function Cos:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
self.dcache = cache(bs, self.shape_in)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Cos:forward(X)
|
function Cos:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if bs ~= self.bs then self:reset_cache(bs) end
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
|
@@ -520,13 +488,6 @@ function Tanh:init()
|
||||||
Layer.init(self, "Tanh")
|
Layer.init(self, "Tanh")
|
||||||
end
|
end
|
||||||
|
|
||||||
function Tanh:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
self.dcache = cache(bs, self.shape_in)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Tanh:forward(X)
|
function Tanh:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if bs ~= self.bs then self:reset_cache(bs) end
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
|
@@ -553,20 +514,11 @@ function Dense:make_shape(parent)
|
||||||
self.biases.shape = {1, self.dim}
|
self.biases.shape = {1, self.dim}
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dense:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
self.cache_x = cache(bs, self.shape_in)
|
|
||||||
self.dcache = cache(bs, self.shape_in)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Dense:forward(X)
|
function Dense:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if self.bs ~= bs then self:reset_cache(bs) end
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
--dot_1aab(X, self.coeffs, Y)
|
|
||||||
dot(X, self.coeffs, 2, 1, Y)
|
dot(X, self.coeffs, 2, 1, Y)
|
||||||
|
|
||||||
for i = 1, self.dim do
|
for i = 1, self.dim do
|
||||||
|
@@ -581,12 +533,6 @@ function Softmax:init()
|
||||||
Layer.init(self, "Softmax")
|
Layer.init(self, "Softmax")
|
||||||
end
|
end
|
||||||
|
|
||||||
function Softmax:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Softmax:forward(X)
|
function Softmax:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if self.bs ~= bs then self:reset_cache(bs) end
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
|
@@ -624,13 +570,6 @@ function Embed:make_shape(parent)
|
||||||
self.shape_out = {parent.shape_out[1] * self.dim}
|
self.shape_out = {parent.shape_out[1] * self.dim}
|
||||||
end
|
end
|
||||||
|
|
||||||
function Embed:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
self.cache_x = cache(bs, self.shape_in)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Embed:forward(X)
|
function Embed:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if self.bs ~= bs then self:reset_cache(bs) end
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
|
@@ -656,11 +595,6 @@ function LayerNorm:init(eps)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
end
|
end
|
||||||
|
|
||||||
function LayerNorm:reset_cache(bs)
|
|
||||||
self.bs = bs
|
|
||||||
self.cache = cache(bs, self.shape_out)
|
|
||||||
end
|
|
||||||
|
|
||||||
function LayerNorm:forward(X)
|
function LayerNorm:forward(X)
|
||||||
local bs = checkshape(X, self.shape_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
if self.bs ~= bs then self:reset_cache(bs) end
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
|
@@ -727,10 +661,6 @@ function Model:forward(inputs)
|
||||||
return outputs
|
return outputs
|
||||||
end
|
end
|
||||||
|
|
||||||
function Model:cleargrad()
|
|
||||||
error("TODO") -- TODO
|
|
||||||
end
|
|
||||||
|
|
||||||
function Model:print()
|
function Model:print()
|
||||||
print("digraph G {")
|
print("digraph G {")
|
||||||
for _, parent in ipairs(self.nodes) do
|
for _, parent in ipairs(self.nodes) do
|
||||||
|
|
Loading…
Add table
Reference in a new issue