diff --git a/onn_core.py b/onn_core.py index 5c76e04..04b0fd0 100644 --- a/onn_core.py +++ b/onn_core.py @@ -589,11 +589,11 @@ class Layer: assert self.parents, self edges = [] for parent in self.parents: - # TODO: skip over irrelevant nodes (if any) - X = values[parent] - if not self.unsafe: - self.validate_input(X) - edges.append(X) + if parent in values: + X = values[parent] + if not self.unsafe: + self.validate_input(X) + edges.append(X) Y = self._propagate(edges, deterministic) if not self.unsafe: self.validate_output(Y) @@ -604,11 +604,11 @@ class Layer: assert self.children, self edges = [] for child in self.children: - # TODO: skip over irrelevant nodes (if any) - dY = values[child] - if not self.unsafe: - self.validate_output(dY) - edges.append(dY) + if child in values: + dY = values[child] + if not self.unsafe: + self.validate_output(dY) + edges.append(dY) dX = self._backpropagate(edges) if not self.unsafe: self.validate_input(dX)