diff --git a/main.lua b/main.lua
index 76dfa4b..f00eaed 100644
--- a/main.lua
+++ b/main.lua
@@ -174,7 +174,7 @@ local function make_network(input_size)
 	--]]
 	nn_z = nn_y
-	nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
+	nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut, true))
 	nn_z = nn_z:feed(nn.Softmax())
 
 	return nn.Model({nn_x, nn_tx}, {nn_z})
 end
diff --git a/nn.lua b/nn.lua
index 752a531..beba779 100644
--- a/nn.lua
+++ b/nn.lua
@@ -549,19 +549,24 @@ function Tanh:forward(X)
 	return Y
 end
 
-function Dense:init(dim)
+function Dense:init(dim, norm_in)
 	Layer.init(self, "Dense")
 	assert(type(dim) == "number")
 	self.dim = dim
 	self.shape_out = {dim}
 	self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
 	self.biases = self:_new_weights(init_zeros)
+	self.norm_in = norm_in and true or false
+	self.c = 1.0
 end
 
 function Dense:make_shape(parent)
 	self.shape_in = parent.shape_out
 	self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
 	self.biases.shape = {1, self.dim}
+	if self.norm_in then
+		self.c = 1 / prod(self.shape_in)
+	end
 end
 
 function Dense:forward(X)
@@ -570,7 +575,7 @@ function Dense:forward(X)
 	local Y = self.cache
 
 	dot(X, self.coeffs, 2, 1, Y)
-	for i, v in ipairs(Y) do Y[i] = v + self.biases[i] end
+	for i, v in ipairs(Y) do Y[i] = self.c * v + self.biases[i] end
 	checkshape(Y, self.shape_out)
 	return Y
 end
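
What the new norm_in flag does, in isolation: the Dense pre-activation X*W is scaled by c = 1/prod(shape_in) before the bias is added, so its magnitude stays roughly independent of the input width. Below is a minimal standalone sketch of that math; prod and dense_forward are illustrative names for this sketch only, not part of the patch's API, and it handles a single unbatched input vector with W stored row-major as a flat dim_in*dim_out table.

	-- illustrative sketch only; mirrors the patched Dense:forward() math
	local function prod(t)
		local p = 1
		for _, v in ipairs(t) do p = p * v end
		return p
	end

	local function dense_forward(X, W, b, norm_in)
		local shape_in = {#X} -- a 1-D input shape, as Dense:make_shape() sees it
		local dim_in, dim_out = #X, #b
		local c = norm_in and 1 / prod(shape_in) or 1.0
		local Y = {}
		for j = 1, dim_out do
			local acc = 0
			for i = 1, dim_in do
				acc = acc + X[i] * W[(i - 1) * dim_out + j]
			end
			Y[j] = c * acc + b[j] -- scale the dot product before the bias, as in the patch
		end
		return Y
	end

With norm_in enabled, doubling the fan-in (with comparably scaled inputs and weights) leaves the typical pre-activation magnitude unchanged. This is presumably the same fan-in concern that He/Glorot-style initialization addresses at init time, applied here at forward time instead, which matters for the final layer since Softmax follows it directly.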