add tanh activation

This commit is contained in:
Connor Olding 2018-05-14 08:27:11 +02:00
parent 3030e83d00
commit 9c8c1ccd0c

25
nn.lua
View File

@ -16,6 +16,7 @@ local print = print
local remove = table.remove
local sin = math.sin
local sqrt = math.sqrt
local tanh = math.tanh
local tostring = tostring
local uniform = math.random
local unpack = table.unpack or unpack
@ -293,6 +294,7 @@ local Merge = Layer:extend()
local Relu = Layer:extend()
local Gelu = Layer:extend()
local Cos = Layer:extend()
local Tanh = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
local Embed = Layer:extend()
@ -515,6 +517,28 @@ function Cos:forward(X)
return Y
end
function Tanh:init()
Layer.init(self, "Tanh")
end
function Tanh:reset_cache(bs)
self.bs = bs
self.cache = cache(bs, self.shape_out)
self.dcache = cache(bs, self.shape_in)
end
function Tanh:forward(X)
local bs = checkshape(X, self.shape_in)
if bs ~= self.bs then self:reset_cache(bs) end
local Y = self.cache
for i = 1, #X do Y[i] = tanh(X[i]) end
checkshape(Y, self.shape_out)
return Y
end
function Dense:init(dim)
Layer.init(self, "Dense")
assert(type(dim) == "number")
@ -813,6 +837,7 @@ return {
Relu = Relu,
Gelu = Gelu,
Cos = Cos,
Tanh = Tanh,
Dense = Dense,
Softmax = Softmax,
Embed = Embed,