nn.lua: use unscaled normal initialization for input-normalized Dense and Embed

- Add init_normal, which fills every slot of t with a normal() sample.
- Dense:init picks init_normal for coeffs when norm_in is truthy and keeps
  init_he_normal otherwise.
- Dense:make_shape sets c to 1 / sqrt(prod(shape_in)) instead of
  1 / prod(shape_in) when norm_in is set.
- Embed:init fills lut with init_normal instead of init_uniform.

NOTE(review): this patch was restored from a copy whose line breaks were lost.
Blank context lines follow from the hunk line counts; the indentation width
inside the hunks is assumed (4 spaces) - confirm against nn.lua, or apply with
`git apply --ignore-whitespace`.

diff --git a/nn.lua b/nn.lua
index beba779..a460d46 100644
--- a/nn.lua
+++ b/nn.lua
@@ -86,6 +86,11 @@ local function init_uniform(t, fan_in, fan_out)
     return t
 end
 
+local function init_normal(t, fan_in, fan_out)
+    for i = 1, #t do t[i] = normal() end
+    return t
+end
+
 local function init_he_uniform(t, fan_in, fan_out)
     local s = sqrt(6 / fan_in)
     for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
@@ -554,7 +559,11 @@ function Dense:init(dim, norm_in)
     assert(type(dim) == "number")
     self.dim = dim
     self.shape_out = {dim}
-    self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
+    if norm_in then
+        self.coeffs = self:_new_weights(init_normal)
+    else
+        self.coeffs = self:_new_weights(init_he_normal)
+    end
     self.biases = self:_new_weights(init_zeros)
     self.norm_in = norm_in and true or false
     self.c = 1.0
@@ -565,7 +574,7 @@ function Dense:make_shape(parent)
     self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
     self.biases.shape = {1, self.dim}
     if self.norm_in then
-        self.c = 1 / prod(self.shape_in)
+        self.c = 1 / sqrt(prod(self.shape_in))
     end
 end
 
@@ -642,7 +651,7 @@ function Embed:init(vocab, dim)
     assert(type(dim) == "number")
     self.vocab = vocab
     self.dim = dim
-    self.lut = self:_new_weights(init_uniform)
+    self.lut = self:_new_weights(init_normal)
     self.lut.shape = {self.vocab, self.dim}
 end
 