support multiple nodal inputs

This commit is contained in:
Connor Olding 2017-09-07 19:09:44 +00:00
parent 2d4ce31c7e
commit db603753f4
2 changed files with 4 additions and 2 deletions

View File

@ -829,7 +829,7 @@ while true do
if enable_network and get_state() == 'playing' or ingame_paused then
local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward(X)
local outputs = network:forward({[nn_x]=X})
-- TODO: predict the *rewards* of all possible actions?
-- that's how DQN seems to work anyway.

4
nn.lua
View File

@ -374,12 +374,14 @@ function Model:reset()
end
end
function Model:forward(X)
function Model:forward(inputs)
local values = {}
local outputs = {}
for i, node in ipairs(self.nodes) do
--print(i, node.name)
if contains(self.nodes_in, node) then
local X = inputs[node]
assert(X ~= nil, ("missing input for node %s"):format(node.name))
values[node] = node:_propagate({X})
else
values[node] = node:propagate(values)