diff --git a/main.lua b/main.lua index 96b1894..cf194ee 100644 --- a/main.lua +++ b/main.lua @@ -682,6 +682,7 @@ end local function init() network = make_network(input_size, learn_start_select and 8 or 6) network:reset() + network:print() print("parameters:", network.n_param) emu.poweron() diff --git a/nn.lua b/nn.lua index 4da53aa..72c9789 100644 --- a/nn.lua +++ b/nn.lua @@ -700,6 +700,7 @@ end function Model:reset() self.n_param = 0 for _, node in ipairs(self.nodes) do + print(node.name, node:get_size()) node:init_weights() self.n_param = self.n_param + node:get_size() end @@ -730,7 +731,15 @@ function Model:cleargrad() end function Model:print() - error("TODO") -- TODO + print("digraph G {") + for _, parent in ipairs(self.nodes) do + if #parent.children then + for _, child in ipairs(parent.children) do + print('\t'..parent.name..'->'..child.name..';') + end + end + end + print('}') end function Model:collect()