merge the right commit this time
This commit is contained in:
commit
1352de7006
4
onn.py
4
onn.py
|
@ -893,11 +893,13 @@ def run(program, args=None):
|
||||||
|
|
||||||
# Model Information {{{2
|
# Model Information {{{2
|
||||||
|
|
||||||
|
print('digraph G {')
|
||||||
for node in model.ordered_nodes:
|
for node in model.ordered_nodes:
|
||||||
children = [str(n) for n in node.children]
|
children = [str(n) for n in node.children]
|
||||||
if children:
|
if children:
|
||||||
sep = '->'
|
sep = '->'
|
||||||
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
|
print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';')
|
||||||
|
print('}')
|
||||||
log('parameters', model.param_count)
|
log('parameters', model.param_count)
|
||||||
|
|
||||||
# Training {{{2
|
# Training {{{2
|
||||||
|
|
52
onn_core.py
52
onn_core.py
|
@ -27,40 +27,48 @@ class LayerIncompatibility(Exception):
|
||||||
# Node Traversal {{{1
|
# Node Traversal {{{1
|
||||||
|
|
||||||
class DummyNode:
|
class DummyNode:
|
||||||
|
name = "Dummy"
|
||||||
|
|
||||||
def __init__(self, children=None, parents=None):
|
def __init__(self, children=None, parents=None):
|
||||||
self.children = children if children is not None else []
|
self.children = children if children is not None else []
|
||||||
self.parents = parents if parents is not None else []
|
self.parents = parents if parents is not None else []
|
||||||
|
|
||||||
def levelorder(field, node_in, nodes=None):
|
def traverse(node_in, node_out, nodes=None, dummy_mode=False):
|
||||||
# relatively inefficient. this function can be optimized.
|
# i have no idea if this is any algorithm in particular.
|
||||||
nodes = nodes if nodes is not None else []
|
nodes = nodes if nodes is not None else []
|
||||||
|
|
||||||
|
seen_up = {}
|
||||||
|
q = [node_out]
|
||||||
|
while len(q) > 0:
|
||||||
|
node = q.pop(0)
|
||||||
|
seen_up[node] = True
|
||||||
|
for parent in node.parents:
|
||||||
|
q.append(parent)
|
||||||
|
|
||||||
|
if dummy_mode:
|
||||||
|
seen_up[node_in] = True
|
||||||
|
|
||||||
|
nodes = []
|
||||||
q = [node_in]
|
q = [node_in]
|
||||||
while len(q) > 0:
|
while len(q) > 0:
|
||||||
node = q.pop(0)
|
node = q.pop(0)
|
||||||
nodes.append(node)
|
if not seen_up[node]:
|
||||||
for child in getattr(node, field):
|
continue
|
||||||
q.append(child)
|
parents_added = (parent in nodes for parent in node.parents)
|
||||||
return nodes
|
if not node in nodes and all(parents_added):
|
||||||
|
|
||||||
def traverse(node_in, node_out, nodes):
|
|
||||||
nodes = nodes if nodes is not None else []
|
|
||||||
down = levelorder('children', node_in)
|
|
||||||
up = levelorder('parents', node_out)
|
|
||||||
seen = {}
|
|
||||||
for node in up:
|
|
||||||
seen[node] = seen.get(node, 0) | 1
|
|
||||||
for node in down:
|
|
||||||
seen[node] = seen.get(node, 0) | 2
|
|
||||||
if seen[node] == 3:
|
|
||||||
nodes.append(node)
|
nodes.append(node)
|
||||||
|
for child in node.children:
|
||||||
|
q.append(child)
|
||||||
|
|
||||||
|
if dummy_mode:
|
||||||
|
nodes.remove(node_in)
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
def traverse_all(nodes_in, nodes_out, nodes=None):
|
def traverse_all(nodes_in, nodes_out, nodes=None):
|
||||||
all_in = DummyNode()
|
all_in = DummyNode(children=nodes_in)
|
||||||
all_out = DummyNode()
|
all_out = DummyNode(parents=nodes_out)
|
||||||
for node in nodes_in: all_in.children.append(node)
|
return traverse(all_in, all_out, nodes, dummy_mode=True)
|
||||||
for node in nodes_out: all_out.parents.append(node)
|
|
||||||
return traverse(all_in, all_out, nodes)
|
|
||||||
|
|
||||||
# Initializations {{{1
|
# Initializations {{{1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user