reduce tile input to 5 per row using new layers
This commit is contained in:
parent
dd5ec3dbde
commit
d3e6441c40
1 changed files with 6 additions and 2 deletions
8
main.lua
8
main.lua
|
@ -141,14 +141,18 @@ package.loaded['nn'] = nil -- DEBUG
|
|||
local nn = require("nn")
|
||||
|
||||
local network
|
||||
local nn_x, nn_tx, nn_ty, nn_y, nn_z
|
||||
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
|
||||
local function make_network(input_size)
|
||||
nn_x = nn.Input({input_size})
|
||||
nn_tx = nn.Input({gcfg.tile_count})
|
||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
||||
nn_tz = nn_ty:feed(nn.Reshape{13, 17 * 2})
|
||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
|
||||
nn_tz = nn_tz:feed(nn.Relu())
|
||||
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
||||
nn_y = nn.Merge()
|
||||
nn_x:feed(nn_y)
|
||||
nn_ty:feed(nn_y)
|
||||
nn_tz:feed(nn_y)
|
||||
|
||||
--[[
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
|
|
Loading…
Reference in a new issue