diff --git a/onn.py b/onn.py index 5f0c05b..35ac465 100755 --- a/onn.py +++ b/onn.py @@ -893,11 +893,13 @@ def run(program, args=None): # Model Information {{{2 + print('digraph G {') for node in model.ordered_nodes: children = [str(n) for n in node.children] if children: sep = '->' - print(str(node) + sep + ('\n' + str(node) + sep).join(children)) + print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';') + print('}') log('parameters', model.param_count) # Training {{{2 diff --git a/onn_core.py b/onn_core.py index 6683046..e8a6458 100644 --- a/onn_core.py +++ b/onn_core.py @@ -27,40 +27,48 @@ class LayerIncompatibility(Exception): # Node Traversal {{{1 class DummyNode: + name = "Dummy" + def __init__(self, children=None, parents=None): self.children = children if children is not None else [] self.parents = parents if parents is not None else [] -def levelorder(field, node_in, nodes=None): - # relatively inefficient. this function can be optimized. +def traverse(node_in, node_out, nodes=None, dummy_mode=False): + # i have no idea if this is any algorithm in particular. nodes = nodes if nodes is not None else [] + + seen_up = {} + q = [node_out] + while len(q) > 0: + node = q.pop(0) + seen_up[node] = True + for parent in node.parents: + q.append(parent) + + if dummy_mode: + seen_up[node_in] = True + + nodes = [] q = [node_in] while len(q) > 0: node = q.pop(0) - nodes.append(node) - for child in getattr(node, field): - q.append(child) - return nodes - -def traverse(node_in, node_out, nodes): - nodes = nodes if nodes is not None else [] - down = levelorder('children', node_in) - up = levelorder('parents', node_out) - seen = {} - for node in up: - seen[node] = seen.get(node, 0) | 1 - for node in down: - seen[node] = seen.get(node, 0) | 2 - if seen[node] == 3: + if not seen_up[node]: + continue + parents_added = (parent in nodes for parent in node.parents) + if not node in nodes and all(parents_added): nodes.append(node) + for child in node.children: + q.append(child) + + if dummy_mode: + nodes.remove(node_in) + return nodes def traverse_all(nodes_in, nodes_out, nodes=None): - all_in = DummyNode() - all_out = DummyNode() - for node in nodes_in: all_in.children.append(node) - for node in nodes_out: all_out.parents.append(node) - return traverse(all_in, all_out, nodes) + all_in = DummyNode(children=nodes_in) + all_out = DummyNode(parents=nodes_out) + return traverse(all_in, all_out, nodes, dummy_mode=True) # Initializations {{{1