make network linear

This commit is contained in:
Connor Olding 2018-06-09 16:20:07 +02:00
parent 2b4bffb401
commit dd5ec3dbde

View file

@ -150,6 +150,7 @@ local function make_network(input_size)
nn_x:feed(nn_y) nn_x:feed(nn_y)
nn_ty:feed(nn_y) nn_ty:feed(nn_y)
--[[
nn_y = nn_y:feed(nn.Dense(128)) nn_y = nn_y:feed(nn.Dense(128))
if cfg.deterministic then if cfg.deterministic then
nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Relu())
@ -157,6 +158,7 @@ local function make_network(input_size)
nn_y = nn_y:feed(nn.Gelu()) nn_y = nn_y:feed(nn.Gelu())
end end
if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
--]]
nn_z = nn_y nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut)) nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))