reduce tile input to 5 per row using new layers

This commit is contained in:
Connor Olding 2018-06-09 16:20:20 +02:00
parent dd5ec3dbde
commit d3e6441c40

View file

@ -141,14 +141,18 @@ package.loaded['nn'] = nil -- DEBUG
local nn = require("nn") local nn = require("nn")
local network local network
local nn_x, nn_tx, nn_ty, nn_y, nn_z local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
local function make_network(input_size) local function make_network(input_size)
nn_x = nn.Input({input_size}) nn_x = nn.Input({input_size})
nn_tx = nn.Input({gcfg.tile_count}) nn_tx = nn.Input({gcfg.tile_count})
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2)) nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
nn_tz = nn_ty:feed(nn.Reshape{13, 17 * 2})
nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
nn_tz = nn_tz:feed(nn.Relu())
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
nn_y = nn.Merge() nn_y = nn.Merge()
nn_x:feed(nn_y) nn_x:feed(nn_y)
nn_ty:feed(nn_y) nn_tz:feed(nn_y)
--[[ --[[
nn_y = nn_y:feed(nn.Dense(128)) nn_y = nn_y:feed(nn.Dense(128))