support multiple nodal inputs
This commit is contained in:
parent
2d4ce31c7e
commit
db603753f4
2 changed files with 4 additions and 2 deletions
2
main.lua
2
main.lua
|
@ -829,7 +829,7 @@ while true do
|
||||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
local choose = deterministic and argmax2 or rchoice2
|
local choose = deterministic and argmax2 or rchoice2
|
||||||
|
|
||||||
local outputs = network:forward(X)
|
local outputs = network:forward({[nn_x]=X})
|
||||||
|
|
||||||
-- TODO: predict the *rewards* of all possible actions?
|
-- TODO: predict the *rewards* of all possible actions?
|
||||||
-- that's how DQN seems to work anyway.
|
-- that's how DQN seems to work anyway.
|
||||||
|
|
4
nn.lua
4
nn.lua
|
@ -374,12 +374,14 @@ function Model:reset()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function Model:forward(X)
|
function Model:forward(inputs)
|
||||||
local values = {}
|
local values = {}
|
||||||
local outputs = {}
|
local outputs = {}
|
||||||
for i, node in ipairs(self.nodes) do
|
for i, node in ipairs(self.nodes) do
|
||||||
--print(i, node.name)
|
--print(i, node.name)
|
||||||
if contains(self.nodes_in, node) then
|
if contains(self.nodes_in, node) then
|
||||||
|
local X = inputs[node]
|
||||||
|
assert(X ~= nil, ("missing input for node %s"):format(node.name))
|
||||||
values[node] = node:_propagate({X})
|
values[node] = node:_propagate({X})
|
||||||
else
|
else
|
||||||
values[node] = node:propagate(values)
|
values[node] = node:propagate(values)
|
||||||
|
|
Loading…
Add table
Reference in a new issue