merge the right commit this time
This commit is contained in:
commit
1352de7006
2 changed files with 33 additions and 23 deletions
4
onn.py
4
onn.py
|
@ -893,11 +893,13 @@ def run(program, args=None):
|
|||
|
||||
# Model Information {{{2
|
||||
|
||||
print('digraph G {')
|
||||
for node in model.ordered_nodes:
|
||||
children = [str(n) for n in node.children]
|
||||
if children:
|
||||
sep = '->'
|
||||
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
|
||||
print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';')
|
||||
print('}')
|
||||
log('parameters', model.param_count)
|
||||
|
||||
# Training {{{2
|
||||
|
|
52
onn_core.py
52
onn_core.py
|
@ -27,40 +27,48 @@ class LayerIncompatibility(Exception):
|
|||
# Node Traversal {{{1
|
||||
|
||||
class DummyNode:
|
||||
name = "Dummy"
|
||||
|
||||
def __init__(self, children=None, parents=None):
|
||||
self.children = children if children is not None else []
|
||||
self.parents = parents if parents is not None else []
|
||||
|
||||
def levelorder(field, node_in, nodes=None):
|
||||
# relatively inefficient. this function can be optimized.
|
||||
def traverse(node_in, node_out, nodes=None, dummy_mode=False):
|
||||
# i have no idea if this is any algorithm in particular.
|
||||
nodes = nodes if nodes is not None else []
|
||||
|
||||
seen_up = {}
|
||||
q = [node_out]
|
||||
while len(q) > 0:
|
||||
node = q.pop(0)
|
||||
seen_up[node] = True
|
||||
for parent in node.parents:
|
||||
q.append(parent)
|
||||
|
||||
if dummy_mode:
|
||||
seen_up[node_in] = True
|
||||
|
||||
nodes = []
|
||||
q = [node_in]
|
||||
while len(q) > 0:
|
||||
node = q.pop(0)
|
||||
nodes.append(node)
|
||||
for child in getattr(node, field):
|
||||
q.append(child)
|
||||
return nodes
|
||||
|
||||
def traverse(node_in, node_out, nodes):
|
||||
nodes = nodes if nodes is not None else []
|
||||
down = levelorder('children', node_in)
|
||||
up = levelorder('parents', node_out)
|
||||
seen = {}
|
||||
for node in up:
|
||||
seen[node] = seen.get(node, 0) | 1
|
||||
for node in down:
|
||||
seen[node] = seen.get(node, 0) | 2
|
||||
if seen[node] == 3:
|
||||
if not seen_up[node]:
|
||||
continue
|
||||
parents_added = (parent in nodes for parent in node.parents)
|
||||
if not node in nodes and all(parents_added):
|
||||
nodes.append(node)
|
||||
for child in node.children:
|
||||
q.append(child)
|
||||
|
||||
if dummy_mode:
|
||||
nodes.remove(node_in)
|
||||
|
||||
return nodes
|
||||
|
||||
def traverse_all(nodes_in, nodes_out, nodes=None):
|
||||
all_in = DummyNode()
|
||||
all_out = DummyNode()
|
||||
for node in nodes_in: all_in.children.append(node)
|
||||
for node in nodes_out: all_out.parents.append(node)
|
||||
return traverse(all_in, all_out, nodes)
|
||||
all_in = DummyNode(children=nodes_in)
|
||||
all_out = DummyNode(parents=nodes_out)
|
||||
return traverse(all_in, all_out, nodes, dummy_mode=True)
|
||||
|
||||
# Initializations {{{1
|
||||
|
||||
|
|
Loading…
Reference in a new issue