tweak inits and norm_in for variances of 1

This commit is contained in:
Connor Olding 2018-06-12 23:39:13 +02:00
parent 74eb2bfbef
commit 7cecd57d05

15
nn.lua
View File

@@ -86,6 +86,11 @@ local function init_uniform(t, fan_in, fan_out)
return t
end
local function init_normal(t, fan_in, fan_out)
for i = 1, #t do t[i] = normal() end
return t
end
local function init_he_uniform(t, fan_in, fan_out)
local s = sqrt(6 / fan_in)
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
@@ -554,7 +559,11 @@ function Dense:init(dim, norm_in)
assert(type(dim) == "number")
self.dim = dim
self.shape_out = {dim}
self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
if norm_in then
self.coeffs = self:_new_weights(init_normal)
else
self.coeffs = self:_new_weights(init_he_normal)
end
self.biases = self:_new_weights(init_zeros)
self.norm_in = norm_in and true or false
self.c = 1.0
@@ -565,7 +574,7 @@ function Dense:make_shape(parent)
self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
self.biases.shape = {1, self.dim}
if self.norm_in then
self.c = 1 / prod(self.shape_in)
self.c = 1 / sqrt(prod(self.shape_in))
end
end
@@ -642,7 +651,7 @@ function Embed:init(vocab, dim)
assert(type(dim) == "number")
self.vocab = vocab
self.dim = dim
self.lut = self:_new_weights(init_uniform)
self.lut = self:_new_weights(init_normal)
self.lut.shape = {self.vocab, self.dim}
end