From 7cecd57d057b341c1a3e2d6e776d693ecf6c6866 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Tue, 12 Jun 2018 23:39:13 +0200 Subject: [PATCH] tweak inits and norm_in for variances of 1 --- nn.lua | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/nn.lua b/nn.lua index beba779..a460d46 100644 --- a/nn.lua +++ b/nn.lua @@ -86,6 +86,11 @@ local function init_uniform(t, fan_in, fan_out) return t end +local function init_normal(t, fan_in, fan_out) + for i = 1, #t do t[i] = normal() end + return t +end + local function init_he_uniform(t, fan_in, fan_out) local s = sqrt(6 / fan_in) for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end @@ -554,7 +559,11 @@ function Dense:init(dim, norm_in) assert(type(dim) == "number") self.dim = dim self.shape_out = {dim} - self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but... + if norm_in then + self.coeffs = self:_new_weights(init_normal) + else + self.coeffs = self:_new_weights(init_he_normal) + end self.biases = self:_new_weights(init_zeros) self.norm_in = norm_in and true or false self.c = 1.0 @@ -565,7 +574,7 @@ function Dense:make_shape(parent) self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim} self.biases.shape = {1, self.dim} if self.norm_in then - self.c = 1 / prod(self.shape_in) + self.c = 1 / sqrt(prod(self.shape_in)) end end @@ -642,7 +651,7 @@ function Embed:init(vocab, dim) assert(type(dim) == "number") self.vocab = vocab self.dim = dim - self.lut = self:_new_weights(init_uniform) + self.lut = self:_new_weights(init_normal) self.lut.shape = {self.vocab, self.dim} end