diff --git a/main.lua b/main.lua index 91334c7..90f1fdb 100644 --- a/main.lua +++ b/main.lua @@ -233,12 +233,12 @@ local function make_network(input_size) nn_ty:feed(nn_y) nn_y = nn_y:feed(nn.Dense(128)) - if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end if cfg.deterministic then nn_y = nn_y:feed(nn.Relu()) else nn_y = nn_y:feed(nn.Gelu()) end + if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end nn_z = nn_y nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))