support multiple nodal inputs

This commit is contained in:
Connor Olding 2017-09-07 19:09:44 +00:00
parent 2d4ce31c7e
commit db603753f4
2 changed files with 4 additions and 2 deletions

View File

@ -829,7 +829,7 @@ while true do
if enable_network and get_state() == 'playing' or ingame_paused then if enable_network and get_state() == 'playing' or ingame_paused then
local choose = deterministic and argmax2 or rchoice2 local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward(X) local outputs = network:forward({[nn_x]=X})
-- TODO: predict the *rewards* of all possible actions? -- TODO: predict the *rewards* of all possible actions?
-- that's how DQN seems to work anyway. -- that's how DQN seems to work anyway.

4
nn.lua
View File

@ -374,12 +374,14 @@ function Model:reset()
end end
end end
function Model:forward(X) function Model:forward(inputs)
local values = {} local values = {}
local outputs = {} local outputs = {}
for i, node in ipairs(self.nodes) do for i, node in ipairs(self.nodes) do
--print(i, node.name) --print(i, node.name)
if contains(self.nodes_in, node) then if contains(self.nodes_in, node) then
local X = inputs[node]
assert(X ~= nil, ("missing input for node %s"):format(node.name))
values[node] = node:_propagate({X}) values[node] = node:_propagate({X})
else else
values[node] = node:propagate(values) values[node] = node:propagate(values)