add tanh activation
This commit is contained in:
parent
3030e83d00
commit
9c8c1ccd0c
1 changed files with 25 additions and 0 deletions
25
nn.lua
25
nn.lua
|
@ -16,6 +16,7 @@ local print = print
|
|||
local remove = table.remove
|
||||
local sin = math.sin
|
||||
local sqrt = math.sqrt
|
||||
local tanh = math.tanh
|
||||
local tostring = tostring
|
||||
local uniform = math.random
|
||||
local unpack = table.unpack or unpack
|
||||
|
@ -293,6 +294,7 @@ local Merge = Layer:extend()
|
|||
local Relu = Layer:extend()
|
||||
local Gelu = Layer:extend()
|
||||
local Cos = Layer:extend()
|
||||
local Tanh = Layer:extend()
|
||||
local Dense = Layer:extend()
|
||||
local Softmax = Layer:extend()
|
||||
local Embed = Layer:extend()
|
||||
|
@ -515,6 +517,28 @@ function Cos:forward(X)
|
|||
return Y
|
||||
end
|
||||
|
||||
function Tanh:init()
|
||||
Layer.init(self, "Tanh")
|
||||
end
|
||||
|
||||
function Tanh:reset_cache(bs)
|
||||
self.bs = bs
|
||||
|
||||
self.cache = cache(bs, self.shape_out)
|
||||
self.dcache = cache(bs, self.shape_in)
|
||||
end
|
||||
|
||||
function Tanh:forward(X)
|
||||
local bs = checkshape(X, self.shape_in)
|
||||
if bs ~= self.bs then self:reset_cache(bs) end
|
||||
local Y = self.cache
|
||||
|
||||
for i = 1, #X do Y[i] = tanh(X[i]) end
|
||||
|
||||
checkshape(Y, self.shape_out)
|
||||
return Y
|
||||
end
|
||||
|
||||
function Dense:init(dim)
|
||||
Layer.init(self, "Dense")
|
||||
assert(type(dim) == "number")
|
||||
|
@ -813,6 +837,7 @@ return {
|
|||
Relu = Relu,
|
||||
Gelu = Gelu,
|
||||
Cos = Cos,
|
||||
Tanh = Tanh,
|
||||
Dense = Dense,
|
||||
Softmax = Softmax,
|
||||
Embed = Embed,
|
||||
|
|
Loading…
Reference in a new issue