tweak inits and norm_in for variances of 1
This commit is contained in:
parent
74eb2bfbef
commit
7cecd57d05
1 changed files with 12 additions and 3 deletions
15
nn.lua
15
nn.lua
|
@ -86,6 +86,11 @@ local function init_uniform(t, fan_in, fan_out)
|
|||
return t
|
||||
end
|
||||
|
||||
local function init_normal(t, fan_in, fan_out)
|
||||
for i = 1, #t do t[i] = normal() end
|
||||
return t
|
||||
end
|
||||
|
||||
local function init_he_uniform(t, fan_in, fan_out)
|
||||
local s = sqrt(6 / fan_in)
|
||||
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
|
||||
|
@ -554,7 +559,11 @@ function Dense:init(dim, norm_in)
|
|||
assert(type(dim) == "number")
|
||||
self.dim = dim
|
||||
self.shape_out = {dim}
|
||||
self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
|
||||
if norm_in then
|
||||
self.coeffs = self:_new_weights(init_normal)
|
||||
else
|
||||
self.coeffs = self:_new_weights(init_he_normal)
|
||||
end
|
||||
self.biases = self:_new_weights(init_zeros)
|
||||
self.norm_in = norm_in and true or false
|
||||
self.c = 1.0
|
||||
|
@ -565,7 +574,7 @@ function Dense:make_shape(parent)
|
|||
self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
|
||||
self.biases.shape = {1, self.dim}
|
||||
if self.norm_in then
|
||||
self.c = 1 / prod(self.shape_in)
|
||||
self.c = 1 / sqrt(prod(self.shape_in))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -642,7 +651,7 @@ function Embed:init(vocab, dim)
|
|||
assert(type(dim) == "number")
|
||||
self.vocab = vocab
|
||||
self.dim = dim
|
||||
self.lut = self:_new_weights(init_uniform)
|
||||
self.lut = self:_new_weights(init_normal)
|
||||
self.lut.shape = {self.vocab, self.dim}
|
||||
end
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue