From 18e4376aaeb6e2deba9fbb33cadede0826ce329b Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Sun, 24 Jun 2018 12:18:30 +0200
Subject: [PATCH] add normalizing and no-biasing features to DenseBroadcast

---
 nn.lua | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/nn.lua b/nn.lua
index 862d8c6..ecb085e 100644
--- a/nn.lua
+++ b/nn.lua
@@ -600,13 +600,22 @@ function Dense:forward(X)
   return Y
 end
 
-function DenseBroadcast:init(dim)
+function DenseBroadcast:init(dim, norm_in, biasing)
   -- same as Dense but applies the same to every m of (m, n).
   Layer.init(self, "DenseBroadcast")
   assert(type(dim) == "number")
   self.dim = dim
-  self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
-  self.biases = self:_new_weights(init_zeros)
+  self.norm_in = norm_in and true or false
+  self.biasing = biasing ~= false -- biases stay enabled unless passed false
+  if self.norm_in then
+    self.coeffs = self:_new_weights(init_normal)
+  else
+    self.coeffs = self:_new_weights(init_he_normal)
+  end
+  if self.biasing then
+    self.biases = self:_new_weights(init_zeros)
+  end
+  self.c = 1.0
 end
 
 function DenseBroadcast:make_shape(parent)
@@ -614,7 +623,12 @@ function DenseBroadcast:make_shape(parent)
   assert(#self.shape_in == 2)
   self.shape_out = {self.shape_in[1], self.dim}
   self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
-  self.biases.shape = {1, self.dim}
+  if self.biasing then
+    self.biases.shape = {1, self.dim}
+  end
+  if self.norm_in then
+    self.c = 1 / sqrt(prod(self.shape_in))
+  end
 end
 
 function DenseBroadcast:forward(X)
@@ -623,7 +637,11 @@ function DenseBroadcast:forward(X)
   local Y = self.cache
 
   dot(X, self.coeffs, 3, 1, Y)
-  for i, v in ipairs(Y) do Y[i] = v + self.biases[(i - 1) % self.dim + 1] end
+  if self.biasing then
+    for i, v in ipairs(Y) do Y[i] = self.c * v + self.biases[(i - 1) % self.dim + 1] end
+  elseif self.norm_in then
+    for i, v in ipairs(Y) do Y[i] = self.c * v end
+  end
 
   checkshape(Y, self.shape_out)
   return Y
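
A minimal usage sketch follows (not part of the patch itself). It assumes nn.lua's Layer classes are callable constructors that forward their arguments to init, which this diff does not show; only the (dim, norm_in, biasing) signature and its behavior come from the patch.

-- Illustrative only; the constructor-call syntax is assumed, not taken from the patch.
local d_default  = DenseBroadcast(16)              -- He-normal coeffs, biases allocated (previous behavior)
local d_normed   = DenseBroadcast(16, true)        -- init_normal coeffs; forward() scales by 1 / sqrt(prod(shape_in))
local d_unbiased = DenseBroadcast(16, true, false) -- pass false explicitly to skip allocating biases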