add tanh activation
This commit is contained in:
parent
3030e83d00
commit
9c8c1ccd0c
1 changed files with 25 additions and 0 deletions
25
nn.lua
25
nn.lua
|
@ -16,6 +16,7 @@ local print = print
|
||||||
local remove = table.remove
|
local remove = table.remove
|
||||||
local sin = math.sin
|
local sin = math.sin
|
||||||
local sqrt = math.sqrt
|
local sqrt = math.sqrt
|
||||||
|
local tanh = math.tanh
|
||||||
local tostring = tostring
|
local tostring = tostring
|
||||||
local uniform = math.random
|
local uniform = math.random
|
||||||
local unpack = table.unpack or unpack
|
local unpack = table.unpack or unpack
|
||||||
|
@ -293,6 +294,7 @@ local Merge = Layer:extend()
|
||||||
local Relu = Layer:extend()
|
local Relu = Layer:extend()
|
||||||
local Gelu = Layer:extend()
|
local Gelu = Layer:extend()
|
||||||
local Cos = Layer:extend()
|
local Cos = Layer:extend()
|
||||||
|
local Tanh = Layer:extend()
|
||||||
local Dense = Layer:extend()
|
local Dense = Layer:extend()
|
||||||
local Softmax = Layer:extend()
|
local Softmax = Layer:extend()
|
||||||
local Embed = Layer:extend()
|
local Embed = Layer:extend()
|
||||||
|
@ -515,6 +517,28 @@ function Cos:forward(X)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Tanh:init()
|
||||||
|
Layer.init(self, "Tanh")
|
||||||
|
end
|
||||||
|
|
||||||
|
function Tanh:reset_cache(bs)
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
self.cache = cache(bs, self.shape_out)
|
||||||
|
self.dcache = cache(bs, self.shape_in)
|
||||||
|
end
|
||||||
|
|
||||||
|
function Tanh:forward(X)
|
||||||
|
local bs = checkshape(X, self.shape_in)
|
||||||
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
|
local Y = self.cache
|
||||||
|
|
||||||
|
for i = 1, #X do Y[i] = tanh(X[i]) end
|
||||||
|
|
||||||
|
checkshape(Y, self.shape_out)
|
||||||
|
return Y
|
||||||
|
end
|
||||||
|
|
||||||
function Dense:init(dim)
|
function Dense:init(dim)
|
||||||
Layer.init(self, "Dense")
|
Layer.init(self, "Dense")
|
||||||
assert(type(dim) == "number")
|
assert(type(dim) == "number")
|
||||||
|
@ -813,6 +837,7 @@ return {
|
||||||
Relu = Relu,
|
Relu = Relu,
|
||||||
Gelu = Gelu,
|
Gelu = Gelu,
|
||||||
Cos = Cos,
|
Cos = Cos,
|
||||||
|
Tanh = Tanh,
|
||||||
Dense = Dense,
|
Dense = Dense,
|
||||||
Softmax = Softmax,
|
Softmax = Softmax,
|
||||||
Embed = Embed,
|
Embed = Embed,
|
||||||
|
|
Loading…
Reference in a new issue