diff --git a/main.lua b/main.lua index 82e1ba0..82b4b64 100644 --- a/main.lua +++ b/main.lua @@ -141,14 +141,18 @@ package.loaded['nn'] = nil -- DEBUG local nn = require("nn") local network -local nn_x, nn_tx, nn_ty, nn_y, nn_z +local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z local function make_network(input_size) nn_x = nn.Input({input_size}) nn_tx = nn.Input({gcfg.tile_count}) nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2)) + nn_tz = nn_ty:feed(nn.Reshape{13, 17 * 2}) + nn_tz = nn_tz:feed(nn.DenseBroadcast(5)) + nn_tz = nn_tz:feed(nn.Relu()) + -- note: due to a quirk in Merge, we don't need to flatten nn_tz. nn_y = nn.Merge() nn_x:feed(nn_y) - nn_ty:feed(nn_y) + nn_tz:feed(nn_y) --[[ nn_y = nn_y:feed(nn.Dense(128))