add graphviz printing and stuff
This commit is contained in:
parent
acc8378980
commit
5a8c0f6140
2 changed files with 11 additions and 1 deletions
1
main.lua
1
main.lua
|
@ -682,6 +682,7 @@ end
|
|||
local function init()
|
||||
network = make_network(input_size, learn_start_select and 8 or 6)
|
||||
network:reset()
|
||||
network:print()
|
||||
print("parameters:", network.n_param)
|
||||
|
||||
emu.poweron()
|
||||
|
|
11
nn.lua
11
nn.lua
|
@ -700,6 +700,7 @@ end
|
|||
function Model:reset()
|
||||
self.n_param = 0
|
||||
for _, node in ipairs(self.nodes) do
|
||||
print(node.name, node:get_size())
|
||||
node:init_weights()
|
||||
self.n_param = self.n_param + node:get_size()
|
||||
end
|
||||
|
@ -730,7 +731,15 @@ function Model:cleargrad()
|
|||
end
|
||||
|
||||
function Model:print()
|
||||
error("TODO") -- TODO
|
||||
print("digraph G {")
|
||||
for _, parent in ipairs(self.nodes) do
|
||||
if #parent.children then
|
||||
for _, child in ipairs(parent.children) do
|
||||
print('\t'..parent.name..'->'..child.name..';')
|
||||
end
|
||||
end
|
||||
end
|
||||
print('}')
|
||||
end
|
||||
|
||||
function Model:collect()
|
||||
|
|
Loading…
Reference in a new issue