add graphviz printing and stuff
This commit is contained in:
parent
acc8378980
commit
5a8c0f6140
2 changed files with 11 additions and 1 deletions
1
main.lua
1
main.lua
|
@ -682,6 +682,7 @@ end
|
||||||
local function init()
|
local function init()
|
||||||
network = make_network(input_size, learn_start_select and 8 or 6)
|
network = make_network(input_size, learn_start_select and 8 or 6)
|
||||||
network:reset()
|
network:reset()
|
||||||
|
network:print()
|
||||||
print("parameters:", network.n_param)
|
print("parameters:", network.n_param)
|
||||||
|
|
||||||
emu.poweron()
|
emu.poweron()
|
||||||
|
|
11
nn.lua
11
nn.lua
|
@ -700,6 +700,7 @@ end
|
||||||
function Model:reset()
|
function Model:reset()
|
||||||
self.n_param = 0
|
self.n_param = 0
|
||||||
for _, node in ipairs(self.nodes) do
|
for _, node in ipairs(self.nodes) do
|
||||||
|
print(node.name, node:get_size())
|
||||||
node:init_weights()
|
node:init_weights()
|
||||||
self.n_param = self.n_param + node:get_size()
|
self.n_param = self.n_param + node:get_size()
|
||||||
end
|
end
|
||||||
|
@ -730,7 +731,15 @@ function Model:cleargrad()
|
||||||
end
|
end
|
||||||
|
|
||||||
function Model:print()
|
function Model:print()
|
||||||
error("TODO") -- TODO
|
print("digraph G {")
|
||||||
|
for _, parent in ipairs(self.nodes) do
|
||||||
|
if #parent.children then
|
||||||
|
for _, child in ipairs(parent.children) do
|
||||||
|
print('\t'..parent.name..'->'..child.name..';')
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
print('}')
|
||||||
end
|
end
|
||||||
|
|
||||||
function Model:collect()
|
function Model:collect()
|
||||||
|
|
Loading…
Reference in a new issue