diff --git a/nn.lua b/nn.lua index a460d46..862d8c6 100644 --- a/nn.lua +++ b/nn.lua @@ -554,25 +554,31 @@ function Tanh:forward(X) return Y end -function Dense:init(dim, norm_in) +function Dense:init(dim, norm_in, biasing) Layer.init(self, "Dense") assert(type(dim) == "number") self.dim = dim self.shape_out = {dim} - if norm_in then + self.norm_in = norm_in and true or false + self.biasing = biasing == nil or biasing + + if self.norm_in then self.coeffs = self:_new_weights(init_normal) else self.coeffs = self:_new_weights(init_he_normal) end - self.biases = self:_new_weights(init_zeros) - self.norm_in = norm_in and true or false + if self.biasing then + self.biases = self:_new_weights(init_zeros) + end self.c = 1.0 end function Dense:make_shape(parent) self.shape_in = parent.shape_out self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim} - self.biases.shape = {1, self.dim} + if self.biasing then + self.biases.shape = {1, self.dim} + end if self.norm_in then self.c = 1 / sqrt(prod(self.shape_in)) end @@ -584,7 +590,11 @@ function Dense:forward(X) local Y = self.cache dot(X, self.coeffs, 2, 1, Y) - for i, v in ipairs(Y) do Y[i] = self.c * v + self.biases[i] end + if self.biasing then + for i, v in ipairs(Y) do Y[i] = self.c * v + self.biases[i] end + elseif self.norm_in then + for i, v in ipairs(Y) do Y[i] = self.c * v end + end checkshape(Y, self.shape_out) return Y