add Reshape and DenseBroadcast layers
This commit is contained in:
parent
ae331ce60b
commit
2b4bffb401
1 changed file with 56 additions and 3 deletions
59
nn.lua
59
nn.lua
|
@ -290,11 +290,13 @@ local Layer = Base:extend()
|
||||||
local Model = Base:extend()
|
local Model = Base:extend()
|
||||||
local Input = Layer:extend()
|
local Input = Layer:extend()
|
||||||
local Merge = Layer:extend()
|
local Merge = Layer:extend()
|
||||||
|
local Reshape = Layer:extend()
|
||||||
local Relu = Layer:extend()
|
local Relu = Layer:extend()
|
||||||
local Gelu = Layer:extend()
|
local Gelu = Layer:extend()
|
||||||
local Cos = Layer:extend()
|
local Cos = Layer:extend()
|
||||||
local Tanh = Layer:extend()
|
local Tanh = Layer:extend()
|
||||||
local Dense = Layer:extend()
|
local Dense = Layer:extend()
|
||||||
|
local DenseBroadcast = Layer:extend()
|
||||||
local Softmax = Layer:extend()
|
local Softmax = Layer:extend()
|
||||||
local Embed = Layer:extend()
|
local Embed = Layer:extend()
|
||||||
local LayerNorm = Layer:extend()
|
local LayerNorm = Layer:extend()
|
||||||
|
@ -436,6 +438,29 @@ function Merge:_propagate(edges, deterministic)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Reshape layer: reinterprets its input under a new shape.
-- The element count must match the parent's output (checked in make_shape).
-- @tparam table shape the desired output shape.
function Reshape:init(shape)
  Layer.init(self, "Reshape")
  -- no trainable parameters.
  self.size = 0
  self.shape_out = shape
end
|
||||||
|
|
||||||
|
--- Binds this layer's input shape to the parent's output shape and
-- verifies the requested reshape preserves the total element count.
-- @tparam Layer parent the upstream layer providing shape_out.
function Reshape:make_shape(parent)
  self.shape_in = parent.shape_out
  -- TODO: allow a single dummy dimension like numpy.
  assert(prod(self.shape_in) == prod(self.shape_out),
         "input shape does not fit into given shape.")
end
|
||||||
|
|
||||||
|
--- Forward pass: a flat copy of X into the cached output buffer.
-- Since data is stored flat, reshaping is just copying; only the
-- declared shape differs.
-- @tparam table X flat input matching self.shape_in (checked).
-- @treturn table the cached output, same values as X.
function Reshape:forward(X)
  local bs = checkshape(X, self.shape_in)
  if bs ~= self.bs then self:reset_cache(bs) end

  local Y = self.cache
  -- element-for-element copy; X is a dense array so a numeric loop
  -- is equivalent to ipairs here.
  for i = 1, #X do
    Y[i] = X[i]
  end
  return Y
end
|
||||||
|
|
||||||
--- Relu layer: rectified linear activation (no trainable parameters).
function Relu:init()
  Layer.init(self, "Relu")
end
|
||||||
|
@ -520,10 +545,36 @@ function Dense:forward(X)
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
dot(X, self.coeffs, 2, 1, Y)
|
dot(X, self.coeffs, 2, 1, Y)
|
||||||
|
for i, v in ipairs(Y) do Y[i] = v + self.biases[i] end
|
||||||
|
|
||||||
for i = 1, self.dim do
|
checkshape(Y, self.shape_out)
|
||||||
Y[i] = Y[i] + self.biases[i]
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- DenseBroadcast layer: an affine transform shared across rows.
-- same as Dense but applies the same to every m of (m, n).
-- @tparam number dim output feature count per row.
function DenseBroadcast:init(dim)
  Layer.init(self, "DenseBroadcast")
  assert(type(dim) == "number")
  self.dim = dim
  self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
  self.biases = self:_new_weights(init_zeros)
end
|
||||||
|
|
||||||
|
--- Derives shapes from the parent: input must be rank-2 (m, n);
-- output is (m, dim). Weights map n -> dim, biases broadcast over rows.
-- @tparam Layer parent the upstream layer providing shape_out.
function DenseBroadcast:make_shape(parent)
  self.shape_in = parent.shape_out
  assert(#self.shape_in == 2)
  local rows = self.shape_in[1]
  local cols = self.shape_in[2] -- innermost dimension (guaranteed rank 2 above)
  self.shape_out = {rows, self.dim}
  self.coeffs.shape = {cols, self.dim}
  self.biases.shape = {1, self.dim}
end
|
||||||
|
|
||||||
|
--- Forward pass: Y = X . coeffs + biases, with the bias row broadcast
-- across every row of the (m, n) input.
-- @tparam table X flat input matching self.shape_in (checked).
-- @treturn table the cached output of shape self.shape_out.
function DenseBroadcast:forward(X)
  local bs = checkshape(X, self.shape_in)
  if self.bs ~= bs then self:reset_cache(bs) end

  local Y = self.cache
  dot(X, self.coeffs, 3, 1, Y)
  -- add the same dim-length bias vector to each row of the flat result;
  -- the modulo maps a flat index back onto its column (1-based).
  local dim = self.dim
  local biases = self.biases
  for i = 1, #Y do
    Y[i] = Y[i] + biases[(i - 1) % dim + 1]
  end

  checkshape(Y, self.shape_out)
  return Y
end
|
||||||
|
@ -763,11 +814,13 @@ return {
|
||||||
Model = Model,
|
Model = Model,
|
||||||
Input = Input,
|
Input = Input,
|
||||||
Merge = Merge,
|
Merge = Merge,
|
||||||
|
Reshape = Reshape,
|
||||||
Relu = Relu,
|
Relu = Relu,
|
||||||
Gelu = Gelu,
|
Gelu = Gelu,
|
||||||
Cos = Cos,
|
Cos = Cos,
|
||||||
Tanh = Tanh,
|
Tanh = Tanh,
|
||||||
Dense = Dense,
|
Dense = Dense,
|
||||||
|
DenseBroadcast = DenseBroadcast,
|
||||||
Softmax = Softmax,
|
Softmax = Softmax,
|
||||||
Embed = Embed,
|
Embed = Embed,
|
||||||
LayerNorm = LayerNorm,
|
LayerNorm = LayerNorm,
|
||||||
|
|
Loading…
Reference in a new issue