allow Dense layers without biasing
This commit is contained in:
parent
fca4779e56
commit
7bb9c79367
1 changed files with 16 additions and 6 deletions
22
nn.lua
22
nn.lua
|
@ -554,25 +554,31 @@ function Tanh:forward(X)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dense:init(dim, norm_in)
|
function Dense:init(dim, norm_in, biasing)
|
||||||
Layer.init(self, "Dense")
|
Layer.init(self, "Dense")
|
||||||
assert(type(dim) == "number")
|
assert(type(dim) == "number")
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.shape_out = {dim}
|
self.shape_out = {dim}
|
||||||
if norm_in then
|
self.norm_in = norm_in and true or false
|
||||||
|
self.biasing = biasing == nil or biasing
|
||||||
|
|
||||||
|
if self.norm_in then
|
||||||
self.coeffs = self:_new_weights(init_normal)
|
self.coeffs = self:_new_weights(init_normal)
|
||||||
else
|
else
|
||||||
self.coeffs = self:_new_weights(init_he_normal)
|
self.coeffs = self:_new_weights(init_he_normal)
|
||||||
end
|
end
|
||||||
self.biases = self:_new_weights(init_zeros)
|
if self.biasing then
|
||||||
self.norm_in = norm_in and true or false
|
self.biases = self:_new_weights(init_zeros)
|
||||||
|
end
|
||||||
self.c = 1.0
|
self.c = 1.0
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dense:make_shape(parent)
|
function Dense:make_shape(parent)
|
||||||
self.shape_in = parent.shape_out
|
self.shape_in = parent.shape_out
|
||||||
self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
|
self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
|
||||||
self.biases.shape = {1, self.dim}
|
if self.biasing then
|
||||||
|
self.biases.shape = {1, self.dim}
|
||||||
|
end
|
||||||
if self.norm_in then
|
if self.norm_in then
|
||||||
self.c = 1 / sqrt(prod(self.shape_in))
|
self.c = 1 / sqrt(prod(self.shape_in))
|
||||||
end
|
end
|
||||||
|
@ -584,7 +590,11 @@ function Dense:forward(X)
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
dot(X, self.coeffs, 2, 1, Y)
|
dot(X, self.coeffs, 2, 1, Y)
|
||||||
for i, v in ipairs(Y) do Y[i] = self.c * v + self.biases[i] end
|
if self.biasing then
|
||||||
|
for i, v in ipairs(Y) do Y[i] = self.c * v + self.biases[i] end
|
||||||
|
elseif self.norm_in then
|
||||||
|
for i, v in ipairs(Y) do Y[i] = self.c * v end
|
||||||
|
end
|
||||||
|
|
||||||
checkshape(Y, self.shape_out)
|
checkshape(Y, self.shape_out)
|
||||||
return Y
|
return Y
|
||||||
|
|
Loading…
Add table
Reference in a new issue