diff --git a/onn_core.py b/onn_core.py index fecb29c..6683046 100644 --- a/onn_core.py +++ b/onn_core.py @@ -24,6 +24,44 @@ _pi = _f(np.pi) class LayerIncompatibility(Exception): pass +# Node Traversal {{{1 + +class DummyNode: + def __init__(self, children=None, parents=None): + self.children = children if children is not None else [] + self.parents = parents if parents is not None else [] + +def levelorder(field, node_in, nodes=None): + # relatively inefficient. this function can be optimized. + nodes = nodes if nodes is not None else [] + q = [node_in] + while len(q) > 0: + node = q.pop(0) + nodes.append(node) + for child in getattr(node, field): + q.append(child) + return nodes + +def traverse(node_in, node_out, nodes): + nodes = nodes if nodes is not None else [] + down = levelorder('children', node_in) + up = levelorder('parents', node_out) + seen = {} + for node in up: + seen[node] = seen.get(node, 0) | 1 + for node in down: + seen[node] = seen.get(node, 0) | 2 + if seen[node] == 3: + nodes.append(node) + return nodes + +def traverse_all(nodes_in, nodes_out, nodes=None): + all_in = DummyNode() + all_out = DummyNode() + for node in nodes_in: all_in.children.append(node) + for node in nodes_out: all_out.parents.append(node) + return traverse(all_in, all_out, nodes) + # Initializations {{{1 # note: these are currently only implemented for 2D shapes. @@ -716,23 +754,30 @@ class Dense(Layer): # Models {{{1 class Model: - def __init__(self, x, y, unsafe=False): - assert isinstance(x, Layer), x - assert isinstance(y, Layer), y - self.x = x - self.y = y - self.ordered_nodes = self.traverse([], self.y) + def __init__(self, nodes_in, nodes_out, unsafe=False): + nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in + nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out + assert type(nodes_in) == list, type(nodes_in) + assert type(nodes_out) == list, type(nodes_out) + self.nodes_in = nodes_in + self.nodes_out = nodes_out + self.nodes = traverse_all(self.nodes_in, self.nodes_out) self.make_weights() - for node in self.ordered_nodes: + for node in self.nodes: node.unsafe = unsafe + @property + def ordered_nodes(self): + # deprecated? we don't guarantee an order like we did before. + return self.nodes + def make_weights(self): - self.param_count = sum((node.size for node in self.ordered_nodes)) + self.param_count = sum((node.size for node in self.nodes)) self.W = np.zeros(self.param_count, dtype=_f) self.dW = np.zeros(self.param_count, dtype=_f) offset = 0 - for node in self.ordered_nodes: + for node in self.nodes: if node.size > 0: inner_offset = 0 @@ -752,39 +797,26 @@ class Model: assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node) offset += node.size - def traverse(self, nodes, node): - if node == self.x: - return [node] - for parent in node.parents: - if parent not in nodes: - new_nodes = self.traverse(nodes, parent) - for new_node in new_nodes: - if new_node not in nodes: - nodes.append(new_node) - if nodes: - nodes.append(node) - return nodes - def forward(self, X, deterministic=False): values = dict() - input_node = self.ordered_nodes[0] - output_node = self.ordered_nodes[-1] + input_node = self.nodes[0] + output_node = self.nodes[-1] values[input_node] = input_node._propagate(np.expand_dims(X, 0), deterministic) - for node in self.ordered_nodes[1:]: + for node in self.nodes[1:]: values[node] = node.propagate(values, deterministic) return values[output_node] def backward(self, error): values = dict() - output_node = self.ordered_nodes[-1] + output_node = self.nodes[-1] values[output_node] = output_node._backpropagate(np.expand_dims(error, 0)) - for node in reversed(self.ordered_nodes[:-1]): + for node in reversed(self.nodes[:-1]): values[node] = node.backpropagate(values) return self.dW def regulate_forward(self): loss = _0 - for node in self.ordered_nodes: + for node in self.nodes: if node.loss is not None: loss += node.loss for k, w in node.weights.items(): @@ -792,7 +824,7 @@ class Model: return loss def regulate(self): - for node in self.ordered_nodes: + for node in self.nodes: for k, w in node.weights.items(): w.update() @@ -812,7 +844,7 @@ class Model: for k in weights.keys(): used[k] = False - nodes = [node for node in self.ordered_nodes if node.size > 0] + nodes = [node for node in self.nodes if node.size > 0] for node in nodes: full_name = str(node).lower() for s_name, o_name in node.serialized.items(): @@ -833,7 +865,7 @@ class Model: counts = defaultdict(lambda: 0) - nodes = [node for node in self.ordered_nodes if node.size > 0] + nodes = [node for node in self.nodes if node.size > 0] for node in nodes: full_name = str(node).lower() grp = f.create_group(full_name)