diff --git a/nn.lua b/nn.lua
index edd9391..467bc2a 100644
--- a/nn.lua
+++ b/nn.lua
@@ -290,11 +290,13 @@ local Layer = Base:extend()
 local Model = Base:extend()
 local Input = Layer:extend()
 local Merge = Layer:extend()
+local Reshape = Layer:extend()
 local Relu = Layer:extend()
 local Gelu = Layer:extend()
 local Cos = Layer:extend()
 local Tanh = Layer:extend()
 local Dense = Layer:extend()
+local DenseBroadcast = Layer:extend()
 local Softmax = Layer:extend()
 local Embed = Layer:extend()
 local LayerNorm = Layer:extend()
@@ -436,6 +438,29 @@ function Merge:_propagate(edges, deterministic)
     return Y
 end
 
+function Reshape:init(shape)
+    Layer.init(self, "Reshape")
+    self.size = 0
+    self.shape_out = shape
+end
+
+function Reshape:make_shape(parent)
+    self.shape_in = parent.shape_out
+    -- TODO: allow a single dummy dimension like numpy.
+    assert(prod(self.shape_in) == prod(self.shape_out),
+           "input shape does not fit into given shape.")
+end
+
+function Reshape:forward(X)
+    local bs = checkshape(X, self.shape_in)
+    if bs ~= self.bs then self:reset_cache(bs) end
+    local Y = self.cache
+
+    for i, v in ipairs(X) do Y[i] = v end
+
+    return Y
+end
+
 function Relu:init()
     Layer.init(self, "Relu")
 end
@@ -520,10 +545,36 @@ function Dense:forward(X)
 
     local Y = self.cache
 
     dot(X, self.coeffs, 2, 1, Y)
+    for i, v in ipairs(Y) do Y[i] = v + self.biases[i] end
-    for i = 1, self.dim do
-        Y[i] = Y[i] + self.biases[i]
-    end
+    checkshape(Y, self.shape_out)
+    return Y
+end
+
+function DenseBroadcast:init(dim)
+    -- same as Dense but applies the same to every m of (m, n).
+    Layer.init(self, "DenseBroadcast")
+    assert(type(dim) == "number")
+    self.dim = dim
+    self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but... 
+    self.biases = self:_new_weights(init_zeros)
+end
+
+function DenseBroadcast:make_shape(parent)
+    self.shape_in = parent.shape_out
+    assert(#self.shape_in == 2)
+    self.shape_out = {self.shape_in[1], self.dim}
+    self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
+    self.biases.shape = {1, self.dim}
+end
+
+function DenseBroadcast:forward(X)
+    local bs = checkshape(X, self.shape_in)
+    if self.bs ~= bs then self:reset_cache(bs) end
+    local Y = self.cache
+
+    dot(X, self.coeffs, 3, 1, Y)
+    for i, v in ipairs(Y) do Y[i] = v + self.biases[(i - 1) % self.dim + 1] end
     checkshape(Y, self.shape_out)
     return Y
 end
@@ -763,11 +814,13 @@ return {
     Model = Model,
     Input = Input,
     Merge = Merge,
+    Reshape = Reshape,
     Relu = Relu,
     Gelu = Gelu,
     Cos = Cos,
     Tanh = Tanh,
     Dense = Dense,
+    DenseBroadcast = DenseBroadcast,
     Softmax = Softmax,
     Embed = Embed,
     LayerNorm = LayerNorm,