diff --git a/nn.lua b/nn.lua index 16e5c88..b2d0b55 100644 --- a/nn.lua +++ b/nn.lua @@ -308,6 +308,7 @@ local Input = Layer:extend() local Merge = Layer:extend() local Relu = Layer:extend() local Gelu = Layer:extend() +local Cos = Layer:extend() local Dense = Layer:extend() local Softmax = Layer:extend() local Embed = Layer:extend() @@ -508,6 +509,28 @@ function Gelu:forward(X) return Y end +function Cos:init() + Layer.init(self, "Cos") +end + +function Cos:reset_cache(bs) + self.bs = bs + + self.cache = cache(bs, self.shape_out) + self.dcache = cache(bs, self.shape_in) +end + +function Cos:forward(X) + local bs = checkshape(X, self.shape_in) + if bs ~= self.bs then self:reset_cache(bs) end + local Y = self.cache + + for i = 1, #X do Y[i] = cos(X[i]) end + + checkshape(Y, self.shape_out) + return Y +end + function Dense:init(dim) Layer.init(self, "Dense") assert(type(dim) == "number") @@ -808,6 +831,7 @@ return { Merge = Merge, Relu = Relu, Gelu = Gelu, + Cos = Cos, Dense = Dense, Softmax = Softmax, Embed = Embed,