make network linear
This commit is contained in:
parent
2b4bffb401
commit
dd5ec3dbde
1 changed files with 2 additions and 0 deletions
2
main.lua
2
main.lua
|
@ -150,6 +150,7 @@ local function make_network(input_size)
|
||||||
nn_x:feed(nn_y)
|
nn_x:feed(nn_y)
|
||||||
nn_ty:feed(nn_y)
|
nn_ty:feed(nn_y)
|
||||||
|
|
||||||
|
--[[
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
nn_y = nn_y:feed(nn.Dense(128))
|
||||||
if cfg.deterministic then
|
if cfg.deterministic then
|
||||||
nn_y = nn_y:feed(nn.Relu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
|
@ -157,6 +158,7 @@ local function make_network(input_size)
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
nn_y = nn_y:feed(nn.Gelu())
|
||||||
end
|
end
|
||||||
if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
|
if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
|
||||||
|
--]]
|
||||||
|
|
||||||
nn_z = nn_y
|
nn_z = nn_y
|
||||||
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
|
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
|
||||||
|
|
Loading…
Reference in a new issue