merge the right commit this time

This commit is contained in:
Connor Olding 2017-07-01 01:14:55 +00:00
commit 1352de7006
2 changed files with 33 additions and 23 deletions

4
onn.py
View file

@ -893,11 +893,13 @@ def run(program, args=None):
# Model Information {{{2
print('digraph G {')
for node in model.ordered_nodes:
children = [str(n) for n in node.children]
if children:
sep = '->'
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';')
print('}')
log('parameters', model.param_count)
# Training {{{2

View file

@ -27,40 +27,48 @@ class LayerIncompatibility(Exception):
# Node Traversal {{{1
class DummyNode:
name = "Dummy"
def __init__(self, children=None, parents=None):
self.children = children if children is not None else []
self.parents = parents if parents is not None else []
def levelorder(field, node_in, nodes=None):
# relatively inefficient. this function can be optimized.
def traverse(node_in, node_out, nodes=None, dummy_mode=False):
# i have no idea if this is any algorithm in particular.
nodes = nodes if nodes is not None else []
seen_up = {}
q = [node_out]
while len(q) > 0:
node = q.pop(0)
seen_up[node] = True
for parent in node.parents:
q.append(parent)
if dummy_mode:
seen_up[node_in] = True
nodes = []
q = [node_in]
while len(q) > 0:
node = q.pop(0)
nodes.append(node)
for child in getattr(node, field):
q.append(child)
return nodes
def traverse(node_in, node_out, nodes):
nodes = nodes if nodes is not None else []
down = levelorder('children', node_in)
up = levelorder('parents', node_out)
seen = {}
for node in up:
seen[node] = seen.get(node, 0) | 1
for node in down:
seen[node] = seen.get(node, 0) | 2
if seen[node] == 3:
if not seen_up[node]:
continue
parents_added = (parent in nodes for parent in node.parents)
if not node in nodes and all(parents_added):
nodes.append(node)
for child in node.children:
q.append(child)
if dummy_mode:
nodes.remove(node_in)
return nodes
def traverse_all(nodes_in, nodes_out, nodes=None):
all_in = DummyNode()
all_out = DummyNode()
for node in nodes_in: all_in.children.append(node)
for node in nodes_out: all_out.parents.append(node)
return traverse(all_in, all_out, nodes)
all_in = DummyNode(children=nodes_in)
all_out = DummyNode(parents=nodes_out)
return traverse(all_in, all_out, nodes, dummy_mode=True)
# Initializations {{{1