diff --git a/main.lua b/main.lua index 004fefd..82e1ba0 100644 --- a/main.lua +++ b/main.lua @@ -150,6 +150,7 @@ local function make_network(input_size) nn_x:feed(nn_y) nn_ty:feed(nn_y) + --[[ nn_y = nn_y:feed(nn.Dense(128)) if cfg.deterministic then nn_y = nn_y:feed(nn.Relu()) @@ -157,6 +158,7 @@ local function make_network(input_size) nn_y = nn_y:feed(nn.Gelu()) end if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end + --]] nn_z = nn_y nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))