diff --git a/nn.lua b/nn.lua index bf83e28..eac6dbf 100644 --- a/nn.lua +++ b/nn.lua @@ -16,6 +16,7 @@ local print = print local remove = table.remove local sin = math.sin local sqrt = math.sqrt +local tanh = math.tanh local tostring = tostring local uniform = math.random local unpack = table.unpack or unpack @@ -293,6 +294,7 @@ local Merge = Layer:extend() local Relu = Layer:extend() local Gelu = Layer:extend() local Cos = Layer:extend() +local Tanh = Layer:extend() local Dense = Layer:extend() local Softmax = Layer:extend() local Embed = Layer:extend() @@ -515,6 +517,28 @@ function Cos:forward(X) return Y end +function Tanh:init() + Layer.init(self, "Tanh") +end + +function Tanh:reset_cache(bs) + self.bs = bs + + self.cache = cache(bs, self.shape_out) + self.dcache = cache(bs, self.shape_in) +end + +function Tanh:forward(X) + local bs = checkshape(X, self.shape_in) + if bs ~= self.bs then self:reset_cache(bs) end + local Y = self.cache + + for i = 1, #X do Y[i] = tanh(X[i]) end + + checkshape(Y, self.shape_out) + return Y +end + function Dense:init(dim) Layer.init(self, "Dense") assert(type(dim) == "number") @@ -813,6 +837,7 @@ return { Relu = Relu, Gelu = Gelu, Cos = Cos, + Tanh = Tanh, Dense = Dense, Softmax = Softmax, Embed = Embed,