diff --git a/nn.lua b/nn.lua
index 3faef88..752a531 100644
--- a/nn.lua
+++ b/nn.lua
@@ -674,6 +674,7 @@ end
 function LayerNorm:forward(X)
     local bs = checkshape(X, self.shape_in)
     if self.bs ~= bs then self:reset_cache(bs) end
+    local Y = self.cache
 
     local mean = 0
     for i, v in ipairs(X) do
@@ -683,16 +684,16 @@ function LayerNorm:forward(X)
     local var = 0
     for i, v in ipairs(X) do
         local delta = v - mean
-        self.cache[i] = delta
+        Y[i] = delta
         var = var + delta * delta / #X
     end
 
     local std = sqrt(var + self.eps)
-    for i, v in ipairs(self.cache) do
-        self.cache[i] = v / std
+    for i, v in ipairs(Y) do
+        Y[i] = v / std
     end
 
-    return self.cache
+    return Y
 end
 
 function Model:init(nodes_in, nodes_out)