move graph printing into Model class
This commit is contained in:
parent
1352de7006
commit
7da93e93a8
3 changed files with 13 additions and 12 deletions
8
onn.py
8
onn.py
|
@ -893,13 +893,7 @@ def run(program, args=None):
|
|||
|
||||
# Model Information {{{2
|
||||
|
||||
print('digraph G {')
|
||||
for node in model.ordered_nodes:
|
||||
children = [str(n) for n in node.children]
|
||||
if children:
|
||||
sep = '->'
|
||||
print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';')
|
||||
print('}')
|
||||
model.print_graph()
|
||||
log('parameters', model.param_count)
|
||||
|
||||
# Training {{{2
|
||||
|
|
11
onn_core.py
11
onn_core.py
|
@ -1,3 +1,5 @@
|
|||
import sys
|
||||
|
||||
import numpy as np
|
||||
_f = np.float32
|
||||
|
||||
|
@ -889,6 +891,15 @@ class Model:
|
|||
|
||||
f.close()
|
||||
|
||||
def print_graph(self, file=sys.stdout):
|
||||
print('digraph G {', file=file)
|
||||
for node in self.nodes:
|
||||
children = [str(n) for n in node.children]
|
||||
if children:
|
||||
sep = '->'
|
||||
print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';', file=file)
|
||||
print('}', file=file)
|
||||
|
||||
# Rituals {{{1
|
||||
|
||||
class Ritual: # i'm just making up names at this point
|
||||
|
|
|
@ -148,11 +148,7 @@ ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
|
|||
#ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss,
|
||||
# input_noise=1e-1, output_noise=3.2e-2, gradient_noise=1e-1)
|
||||
|
||||
for node in model.ordered_nodes:
|
||||
children = [str(n) for n in node.children]
|
||||
if children:
|
||||
sep = '->'
|
||||
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
|
||||
model.print_graph()
|
||||
log('parameters', model.param_count)
|
||||
|
||||
ritual.prepare(model)
|
||||
|
|
Loading…
Reference in a new issue