From db603753f4612f46cbdad8d66b3104fb9c894955 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 7 Sep 2017 19:09:44 +0000 Subject: [PATCH] support multiple nodal inputs --- main.lua | 2 +- nn.lua | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/main.lua b/main.lua index db32810..fa98ce0 100644 --- a/main.lua +++ b/main.lua @@ -829,7 +829,7 @@ while true do if enable_network and get_state() == 'playing' or ingame_paused then local choose = deterministic and argmax2 or rchoice2 - local outputs = network:forward(X) + local outputs = network:forward({[nn_x]=X}) -- TODO: predict the *rewards* of all possible actions? -- that's how DQN seems to work anyway. diff --git a/nn.lua b/nn.lua index 48e123a..3b5f94f 100644 --- a/nn.lua +++ b/nn.lua @@ -374,12 +374,14 @@ function Model:reset() end end -function Model:forward(X) +function Model:forward(inputs) local values = {} local outputs = {} for i, node in ipairs(self.nodes) do --print(i, node.name) if contains(self.nodes_in, node) then + local X = inputs[node] + assert(X ~= nil, ("missing input for node %s"):format(node.name)) values[node] = node:_propagate({X}) else values[node] = node:propagate(values)