begin work on multiple input/output nodes

This commit is contained in:
Connor Olding 2017-07-01 00:44:56 +00:00
parent a7c4bdaa2e
commit 69786b40a1
2 changed files with 74 additions and 32 deletions

4
onn.py
View File

@ -893,11 +893,13 @@ def run(program, args=None):
# Model Information {{{2
print('digraph G {')
for node in model.ordered_nodes:
children = [str(n) for n in node.children]
if children:
sep = '->'
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
print('\t' + str(node) + sep + (';\n\t' + str(node) + sep).join(children) + ';')
print('}')
log('parameters', model.param_count)
# Training {{{2

View File

@ -24,6 +24,52 @@ _pi = _f(np.pi)
class LayerIncompatibility(Exception):
pass
# Node Traversal {{{1
class DummyNode:
    """Placeholder graph node that only carries connectivity.

    It exposes the same `children` / `parents` lists that graph traversal
    reads from real nodes, plus a fixed display `name`.
    """

    name = "Dummy"

    def __init__(self, children=None, parents=None):
        # A caller-supplied list is stored as-is (not copied); a fresh empty
        # list is created only when the argument is omitted.
        if children is None:
            children = []
        if parents is None:
            parents = []
        self.children = children
        self.parents = parents
def traverse(node_in, node_out, nodes=None, dummy_mode=False):
    """Return the nodes lying between node_in and node_out, parents first.

    The graph is walked twice: once upward from `node_out` through
    `.parents` to find every node the output depends on, then downward from
    `node_in` through `.children`, emitting a node only once all of its
    parents have been emitted. Nodes reachable from `node_in` that the
    output does not depend on are skipped.

    Parameters:
        node_in: node to start the downward walk from.
        node_out: node to start the upward walk from.
        nodes: accepted for backward compatibility; the result is always
            built from scratch, so this value is ignored.
        dummy_mode: set when `node_in` is a stand-in that is not listed in
            any real node's `.parents`. It is then treated as a dependency
            of the output and removed from the result before returning.

    Returns:
        A new list of nodes in which every node appears after its parents.

    Nodes must be hashable and expose `.parents` and `.children` iterables.
    """
    from collections import deque

    # pass 1: collect every ancestor of node_out (node_out included).
    # the visited check keeps shared ancestors from being expanded
    # once per path leading to them.
    reachable = set()
    pending = deque([node_out])
    while pending:
        node = pending.popleft()
        if node in reachable:
            continue
        reachable.add(node)
        pending.extend(node.parents)

    if dummy_mode:
        # the stand-in input is never found by the upward walk.
        reachable.add(node_in)

    # pass 2: breadth-first walk down from node_in. a node may be dequeued
    # several times; it is emitted on the first visit at which all of its
    # parents have already been emitted.
    ordered = []
    emitted = set()  # mirrors `ordered` for O(1) membership tests
    pending = deque([node_in])
    while pending:
        node = pending.popleft()
        # membership test rather than indexing: a node that node_out does
        # not depend on is simply skipped instead of raising KeyError.
        if node not in reachable:
            continue
        if node not in emitted and all(p in emitted for p in node.parents):
            ordered.append(node)
            emitted.add(node)
        pending.extend(node.children)

    if dummy_mode:
        ordered.remove(node_in)

    return ordered
def traverse_all(nodes_in, nodes_out, nodes=None):
    """Traverse a graph that has several input and several output nodes.

    A DummyNode whose children are `nodes_in` and a DummyNode whose parents
    are `nodes_out` are created as the two endpoints, and the work is handed
    to `traverse` with `dummy_mode=True`. `nodes` is forwarded unchanged.
    """
    source = DummyNode(children=nodes_in)
    sink = DummyNode(parents=nodes_out)
    return traverse(source, sink, nodes, dummy_mode=True)
# Initializations {{{1
# note: these are currently only implemented for 2D shapes.
@ -716,23 +762,30 @@ class Dense(Layer):
# Models {{{1
class Model:
def __init__(self, x, y, unsafe=False):
assert isinstance(x, Layer), x
assert isinstance(y, Layer), y
self.x = x
self.y = y
self.ordered_nodes = self.traverse([], self.y)
def __init__(self, nodes_in, nodes_out, unsafe=False):
nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in
nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out
assert type(nodes_in) == list, type(nodes_in)
assert type(nodes_out) == list, type(nodes_out)
self.nodes_in = nodes_in
self.nodes_out = nodes_out
self.nodes = traverse_all(self.nodes_in, self.nodes_out)
self.make_weights()
for node in self.ordered_nodes:
for node in self.nodes:
node.unsafe = unsafe
@property
def ordered_nodes(self):
# deprecated? we don't guarantee an order like we did before.
return self.nodes
def make_weights(self):
self.param_count = sum((node.size for node in self.ordered_nodes))
self.param_count = sum((node.size for node in self.nodes))
self.W = np.zeros(self.param_count, dtype=_f)
self.dW = np.zeros(self.param_count, dtype=_f)
offset = 0
for node in self.ordered_nodes:
for node in self.nodes:
if node.size > 0:
inner_offset = 0
@ -752,39 +805,26 @@ class Model:
assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
offset += node.size
def traverse(self, nodes, node):
if node == self.x:
return [node]
for parent in node.parents:
if parent not in nodes:
new_nodes = self.traverse(nodes, parent)
for new_node in new_nodes:
if new_node not in nodes:
nodes.append(new_node)
if nodes:
nodes.append(node)
return nodes
def forward(self, X, deterministic=False):
values = dict()
input_node = self.ordered_nodes[0]
output_node = self.ordered_nodes[-1]
input_node = self.nodes[0]
output_node = self.nodes[-1]
values[input_node] = input_node._propagate(np.expand_dims(X, 0), deterministic)
for node in self.ordered_nodes[1:]:
for node in self.nodes[1:]:
values[node] = node.propagate(values, deterministic)
return values[output_node]
def backward(self, error):
values = dict()
output_node = self.ordered_nodes[-1]
output_node = self.nodes[-1]
values[output_node] = output_node._backpropagate(np.expand_dims(error, 0))
for node in reversed(self.ordered_nodes[:-1]):
for node in reversed(self.nodes[:-1]):
values[node] = node.backpropagate(values)
return self.dW
def regulate_forward(self):
loss = _0
for node in self.ordered_nodes:
for node in self.nodes:
if node.loss is not None:
loss += node.loss
for k, w in node.weights.items():
@ -792,7 +832,7 @@ class Model:
return loss
def regulate(self):
for node in self.ordered_nodes:
for node in self.nodes:
for k, w in node.weights.items():
w.update()
@ -812,7 +852,7 @@ class Model:
for k in weights.keys():
used[k] = False
nodes = [node for node in self.ordered_nodes if node.size > 0]
nodes = [node for node in self.nodes if node.size > 0]
for node in nodes:
full_name = str(node).lower()
for s_name, o_name in node.serialized.items():
@ -833,7 +873,7 @@ class Model:
counts = defaultdict(lambda: 0)
nodes = [node for node in self.ordered_nodes if node.size > 0]
nodes = [node for node in self.nodes if node.size > 0]
for node in nodes:
full_name = str(node).lower()
grp = f.create_group(full_name)