diff --git a/config.lua b/config.lua
index a128f83..562f686 100644
--- a/config.lua
+++ b/config.lua
@@ -26,7 +26,9 @@ local defaults = {
     time_inputs = true, -- insert binary inputs of a frame counter.
 
     -- network layers:
-    layernorm = false, -- (doesn't do anything right now)
+    hidden = false, -- use a hidden layer with ReLU/GELU activation.
+    hidden_size = 128,
+    layernorm = false, -- use a LayerNorm layer after said activation.
     reduce_tiles = false,
     bias_out = true,
 
diff --git a/main.lua b/main.lua
index 03ba74e..7a62b45 100644
--- a/main.lua
+++ b/main.lua
@@ -163,15 +163,15 @@ local function make_network(input_size)
     nn_x:feed(nn_y)
     nn_tz:feed(nn_y)
 
-    --[[
-    nn_y = nn_y:feed(nn.Dense(128))
-    if cfg.deterministic then
-        nn_y = nn_y:feed(nn.Relu())
-    else
-        nn_y = nn_y:feed(nn.Gelu())
+    if cfg.hidden then
+        nn_y = nn_y:feed(nn.Dense(cfg.hidden_size, true))
+        if cfg.deterministic then
+            nn_y = nn_y:feed(nn.Relu())
+        else
+            nn_y = nn_y:feed(nn.Gelu())
+        end
+        if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
     end
-    if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
-    --]]
 
     nn_z = nn_y
     nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true, cfg.bias_out)
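
For reference, enabling the new layer from a user config might look like the sketch below. This is hypothetical: it assumes user settings are merged over the `defaults` table in config.lua, and the override mechanism itself is not shown in this diff.

-- Hypothetical user config: turn on the hidden layer added above.
local overrides = {
    hidden = true,      -- build the Dense -> ReLU/GELU block in make_network.
    hidden_size = 256,  -- width of the hidden nn.Dense layer (default 128).
    layernorm = true,   -- append nn.LayerNorm after the activation.
}
return overrides

With `hidden = true` and a non-deterministic config, the resulting stack on nn_y is Dense(hidden_size) -> Gelu -> (optional LayerNorm) before the final Dense output layer; with `cfg.deterministic` set, Relu replaces Gelu.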