From 7da93e93a814eccc98a70781b1a38046ca565aec Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sat, 1 Jul 2017 02:17:46 +0000 Subject: [PATCH] move graph printing into Model class --- onn.py | 8 +------- onn_core.py | 11 +++++++++++ onn_mnist.py | 6 +----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/onn.py b/onn.py index 35ac465..335dcc8 100755 --- a/onn.py +++ b/onn.py @@ -893,13 +893,7 @@ def run(program, args=None): # Model Information {{{2 - print('digraph G {') - for node in model.ordered_nodes: - children = [str(n) for n in node.children] - if children: - sep = '->' - print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';') - print('}') + model.print_graph() log('parameters', model.param_count) # Training {{{2 diff --git a/onn_core.py b/onn_core.py index e8a6458..59f5e7d 100644 --- a/onn_core.py +++ b/onn_core.py @@ -1,3 +1,5 @@ +import sys + import numpy as np _f = np.float32 @@ -889,6 +891,15 @@ class Model: f.close() + def print_graph(self, file=sys.stdout): + print('digraph G {', file=file) + for node in self.nodes: + children = [str(n) for n in node.children] + if children: + sep = '->' + print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';', file=file) + print('}', file=file) + # Rituals {{{1 class Ritual: # i'm just making up names at this point diff --git a/onn_mnist.py b/onn_mnist.py index 46e2f91..4fdcfe0 100755 --- a/onn_mnist.py +++ b/onn_mnist.py @@ -148,11 +148,7 @@ ritual = Ritual(learner=learner, loss=loss, mloss=mloss) #ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss, # input_noise=1e-1, output_noise=3.2e-2, gradient_noise=1e-1) -for node in model.ordered_nodes: - children = [str(n) for n in node.children] - if children: - sep = '->' - print(str(node) + sep + ('\n' + str(node) + sep).join(children)) +model.print_graph() log('parameters', model.param_count) ritual.prepare(model)