add graphviz printing and stuff

This commit is contained in:
Connor Olding 2017-09-07 23:06:43 +00:00
parent acc8378980
commit 5a8c0f6140
2 changed files with 11 additions and 1 deletions

View file

@ -682,6 +682,7 @@ end
local function init() local function init()
network = make_network(input_size, learn_start_select and 8 or 6) network = make_network(input_size, learn_start_select and 8 or 6)
network:reset() network:reset()
network:print()
print("parameters:", network.n_param) print("parameters:", network.n_param)
emu.poweron() emu.poweron()

11
nn.lua
View file

@ -700,6 +700,7 @@ end
function Model:reset() function Model:reset()
self.n_param = 0 self.n_param = 0
for _, node in ipairs(self.nodes) do for _, node in ipairs(self.nodes) do
print(node.name, node:get_size())
node:init_weights() node:init_weights()
self.n_param = self.n_param + node:get_size() self.n_param = self.n_param + node:get_size()
end end
@ -730,7 +731,15 @@ function Model:cleargrad()
end end
function Model:print() function Model:print()
error("TODO") -- TODO print("digraph G {")
for _, parent in ipairs(self.nodes) do
if #parent.children then
for _, child in ipairs(parent.children) do
print('\t'..parent.name..'->'..child.name..';')
end
end
end
print('}')
end end
function Model:collect() function Model:collect()