allow multi-input and multi-output models
This commit is contained in:
parent
3386869b30
commit
d38e2076f0
1 changed files with 40 additions and 14 deletions
54
onn_core.py
54
onn_core.py
|
@ -932,23 +932,49 @@ class Model:
|
||||||
assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
|
assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
|
||||||
offset += node.size
|
offset += node.size
|
||||||
|
|
||||||
def evaluate(self, inputs, deterministic=True):
|
def evaluate(self, input_, deterministic=True):
|
||||||
values = dict()
|
assert len(self.nodes_in) == 1, "ambiguous input in multi-input network; use evaluate_multi() instead"
|
||||||
input_node = self.nodes[0]
|
assert len(self.nodes_out) == 1, "ambiguous output in multi-output network; use evaluate_multi() instead"
|
||||||
output_node = self.nodes[-1]
|
node_in = self.nodes_in[0]
|
||||||
values[input_node] = input_node._propagate(np.expand_dims(inputs, 0), deterministic)
|
node_out = self.nodes_out[0]
|
||||||
for node in self.nodes[1:]:
|
outputs = self.evaluate_multi({node_in: input_}, deterministic)
|
||||||
values[node] = node.propagate(values, deterministic)
|
return outputs[node_out]
|
||||||
return values[output_node]
|
|
||||||
|
|
||||||
def apply(self, error): # TODO: better name?
|
def apply(self, error): # TODO: better name?
|
||||||
|
assert len(self.nodes_in) == 1, "ambiguous input in multi-input network; use apply_multi() instead"
|
||||||
|
assert len(self.nodes_out) == 1, "ambiguous output in multi-output network; use apply_multi() instead"
|
||||||
|
node_in = self.nodes_in[0]
|
||||||
|
node_out = self.nodes_out[0]
|
||||||
|
inputs = self.apply_multi({node_out: error})
|
||||||
|
return inputs[node_in]
|
||||||
|
|
||||||
|
def evaluate_multi(self, inputs, deterministic=True):
|
||||||
values = dict()
|
values = dict()
|
||||||
input_node = self.nodes[0]
|
outputs = dict()
|
||||||
output_node = self.nodes[-1]
|
for node in self.nodes:
|
||||||
values[output_node] = output_node._backpropagate(np.expand_dims(error, 0))
|
if node in self.nodes_in:
|
||||||
for node in reversed(self.nodes[:-1]):
|
assert node in inputs, "missing input for node {}".format(node.name)
|
||||||
values[node] = node.backpropagate(values)
|
X = inputs[node]
|
||||||
return values[input_node]
|
values[node] = node._propagate(np.expand_dims(X, 0), deterministic)
|
||||||
|
else:
|
||||||
|
values[node] = node.propagate(values, deterministic)
|
||||||
|
if node in self.nodes_out:
|
||||||
|
outputs[node] = values[node]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def apply_multi(self, outputs):
|
||||||
|
values = dict()
|
||||||
|
inputs = dict()
|
||||||
|
for node in reversed(self.nodes):
|
||||||
|
if node in self.nodes_out:
|
||||||
|
assert node in outputs, "missing output for node {}".format(node.name)
|
||||||
|
X = outputs[node]
|
||||||
|
values[node] = node._backpropagate(np.expand_dims(X, 0))
|
||||||
|
else:
|
||||||
|
values[node] = node.backpropagate(values)
|
||||||
|
if node in self.nodes_in:
|
||||||
|
inputs[node] = values[node]
|
||||||
|
return inputs
|
||||||
|
|
||||||
def forward(self, inputs, outputs, measure=False, deterministic=False):
|
def forward(self, inputs, outputs, measure=False, deterministic=False):
|
||||||
predicted = self.evaluate(inputs, deterministic=deterministic)
|
predicted = self.evaluate(inputs, deterministic=deterministic)
|
||||||
|
|
Loading…
Add table
Reference in a new issue