class LayerIncompatibility(Exception):
    pass

# Node Traversal {{{1

class DummyNode:
    """Bare graph node that only carries `children` and `parents` lists.

    `traverse_all` creates two of these to act as a single synthetic source
    (whose children are the real input nodes) and a single synthetic sink
    (whose parents are the real output nodes). The real nodes are never
    modified to point back at a DummyNode.
    """

    name = "Dummy"

    def __init__(self, children=None, parents=None):
        # None stands in for "no list given" so instances never share a
        # mutable default.
        self.children = children if children is not None else []
        self.parents = parents if parents is not None else []

def traverse(node_in, node_out, nodes=None, dummy_mode=False):
    """Return the nodes lying between `node_in` and `node_out`, parents first.

    The walk is done in two passes:

    1. follow `.parents` upward from `node_out` to collect every ancestor
       of the output;
    2. follow `.children` downward from `node_in` breadth-first, emitting a
       node only if it was collected in pass 1 and all of its parents have
       already been emitted.

    Nodes reachable from `node_in` that do not feed `node_out` are skipped.
    Because of the rule in pass 2, a node with a parent that is never
    emitted (for example a `node_in` that itself has parents) is never
    emitted either.

    Parameters:
        node_in: node to start the downward pass from. Needs `.parents` and
            `.children` iterables and must be hashable.
        node_out: node to start the upward pass from; same requirements.
        nodes: unused. The parameter is kept so existing call sites keep
            working; the result is always a newly created list.
        dummy_mode: when true, `node_in` is treated as an ancestor of
            `node_out` even though the upward pass cannot reach it, and it
            is removed from the result before returning.

    Returns:
        A new list of nodes in which every node appears after all of its
        parents. Empty when `node_in` does not lead to `node_out`.
    """
    from collections import deque

    # pass 1: everything the output depends on.
    # the visited check keeps this linear on graphs where paths reconverge.
    seen_up = set()
    queue = deque([node_out])
    while queue:
        node = queue.popleft()
        if node in seen_up:
            continue
        seen_up.add(node)
        queue.extend(node.parents)

    if dummy_mode:
        # real nodes do not list the synthetic source as a parent, so the
        # upward pass can never find it on its own.
        seen_up.add(node_in)

    # pass 2: breadth-first from the input. a node popped before all of its
    # parents are emitted is dropped for now; it gets queued again when its
    # remaining parents are emitted.
    ordered = []
    added = set()  # mirrors `ordered` for O(1) membership tests.
    queue = deque([node_in])
    while queue:
        node = queue.popleft()
        if node not in seen_up or node in added:
            continue
        if all(parent in added for parent in node.parents):
            ordered.append(node)
            added.add(node)
            queue.extend(node.children)

    if dummy_mode and node_in in added:
        ordered.remove(node_in)

    return ordered

def traverse_all(nodes_in, nodes_out, nodes=None):
    """Order the nodes between several inputs and several outputs.

    Wraps `nodes_in` and `nodes_out` in a synthetic source and sink
    (`DummyNode`) and runs `traverse` in dummy mode, so neither synthetic
    node appears in the result.

    Parameters:
        nodes_in: iterable of input nodes.
        nodes_out: iterable of output nodes.
        nodes: forwarded to `traverse`, which does not use it.

    Returns:
        A new list of the real nodes, each one after all of its parents.
    """
    all_in = DummyNode(children=list(nodes_in))
    all_out = DummyNode(parents=list(nodes_out))
    return traverse(all_in, all_out, nodes, dummy_mode=True)

# Initializations {{{1

# note: these are currently only implemented for 2D shapes.
@@ -716,23 +762,30 @@ class Dense(Layer): # Models {{{1 class Model: - def __init__(self, x, y, unsafe=False): - assert isinstance(x, Layer), x - assert isinstance(y, Layer), y - self.x = x - self.y = y - self.ordered_nodes = self.traverse([], self.y) + def __init__(self, nodes_in, nodes_out, unsafe=False): + nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in + nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out + assert type(nodes_in) == list, type(nodes_in) + assert type(nodes_out) == list, type(nodes_out) + self.nodes_in = nodes_in + self.nodes_out = nodes_out + self.nodes = traverse_all(self.nodes_in, self.nodes_out) self.make_weights() - for node in self.ordered_nodes: + for node in self.nodes: node.unsafe = unsafe + @property + def ordered_nodes(self): + # deprecated? we don't guarantee an order like we did before. + return self.nodes + def make_weights(self): - self.param_count = sum((node.size for node in self.ordered_nodes)) + self.param_count = sum((node.size for node in self.nodes)) self.W = np.zeros(self.param_count, dtype=_f) self.dW = np.zeros(self.param_count, dtype=_f) offset = 0 - for node in self.ordered_nodes: + for node in self.nodes: if node.size > 0: inner_offset = 0 @@ -752,39 +805,26 @@ class Model: assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node) offset += node.size - def traverse(self, nodes, node): - if node == self.x: - return [node] - for parent in node.parents: - if parent not in nodes: - new_nodes = self.traverse(nodes, parent) - for new_node in new_nodes: - if new_node not in nodes: - nodes.append(new_node) - if nodes: - nodes.append(node) - return nodes - def forward(self, X, deterministic=False): values = dict() - input_node = self.ordered_nodes[0] - output_node = self.ordered_nodes[-1] + input_node = self.nodes[0] + output_node = self.nodes[-1] values[input_node] = input_node._propagate(np.expand_dims(X, 0), deterministic) - for node in 
self.ordered_nodes[1:]: + for node in self.nodes[1:]: values[node] = node.propagate(values, deterministic) return values[output_node] def backward(self, error): values = dict() - output_node = self.ordered_nodes[-1] + output_node = self.nodes[-1] values[output_node] = output_node._backpropagate(np.expand_dims(error, 0)) - for node in reversed(self.ordered_nodes[:-1]): + for node in reversed(self.nodes[:-1]): values[node] = node.backpropagate(values) return self.dW def regulate_forward(self): loss = _0 - for node in self.ordered_nodes: + for node in self.nodes: if node.loss is not None: loss += node.loss for k, w in node.weights.items(): @@ -792,7 +832,7 @@ class Model: return loss def regulate(self): - for node in self.ordered_nodes: + for node in self.nodes: for k, w in node.weights.items(): w.update() @@ -812,7 +852,7 @@ class Model: for k in weights.keys(): used[k] = False - nodes = [node for node in self.ordered_nodes if node.size > 0] + nodes = [node for node in self.nodes if node.size > 0] for node in nodes: full_name = str(node).lower() for s_name, o_name in node.serialized.items(): @@ -833,7 +873,7 @@ class Model: counts = defaultdict(lambda: 0) - nodes = [node for node in self.ordered_nodes if node.size > 0] + nodes = [node for node in self.nodes if node.size > 0] for node in nodes: full_name = str(node).lower() grp = f.create_group(full_name)