begin support for multiple input/output layers
This commit is contained in:
parent
a7c4bdaa2e
commit
a530558fc1
1 changed files with 63 additions and 31 deletions
94
onn_core.py
94
onn_core.py
|
@ -24,6 +24,44 @@ _pi = _f(np.pi)
|
||||||
class LayerIncompatibility(Exception):
|
class LayerIncompatibility(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Node Traversal {{{1
|
||||||
|
|
||||||
|
class DummyNode:
|
||||||
|
def __init__(self, children=None, parents=None):
|
||||||
|
self.children = children if children is not None else []
|
||||||
|
self.parents = parents if parents is not None else []
|
||||||
|
|
||||||
|
def levelorder(field, node_in, nodes=None):
|
||||||
|
# relatively inefficient. this function can be optimized.
|
||||||
|
nodes = nodes if nodes is not None else []
|
||||||
|
q = [node_in]
|
||||||
|
while len(q) > 0:
|
||||||
|
node = q.pop(0)
|
||||||
|
nodes.append(node)
|
||||||
|
for child in getattr(node, field):
|
||||||
|
q.append(child)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
def traverse(node_in, node_out, nodes):
|
||||||
|
nodes = nodes if nodes is not None else []
|
||||||
|
down = levelorder('children', node_in)
|
||||||
|
up = levelorder('parents', node_out)
|
||||||
|
seen = {}
|
||||||
|
for node in up:
|
||||||
|
seen[node] = seen.get(node, 0) | 1
|
||||||
|
for node in down:
|
||||||
|
seen[node] = seen.get(node, 0) | 2
|
||||||
|
if seen[node] == 3:
|
||||||
|
nodes.append(node)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
def traverse_all(nodes_in, nodes_out, nodes=None):
|
||||||
|
all_in = DummyNode()
|
||||||
|
all_out = DummyNode()
|
||||||
|
for node in nodes_in: all_in.children.append(node)
|
||||||
|
for node in nodes_out: all_out.parents.append(node)
|
||||||
|
return traverse(all_in, all_out, nodes)
|
||||||
|
|
||||||
# Initializations {{{1
|
# Initializations {{{1
|
||||||
|
|
||||||
# note: these are currently only implemented for 2D shapes.
|
# note: these are currently only implemented for 2D shapes.
|
||||||
|
@ -716,23 +754,30 @@ class Dense(Layer):
|
||||||
# Models {{{1
|
# Models {{{1
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
def __init__(self, x, y, unsafe=False):
|
def __init__(self, nodes_in, nodes_out, unsafe=False):
|
||||||
assert isinstance(x, Layer), x
|
nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in
|
||||||
assert isinstance(y, Layer), y
|
nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out
|
||||||
self.x = x
|
assert type(nodes_in) == list, type(nodes_in)
|
||||||
self.y = y
|
assert type(nodes_out) == list, type(nodes_out)
|
||||||
self.ordered_nodes = self.traverse([], self.y)
|
self.nodes_in = nodes_in
|
||||||
|
self.nodes_out = nodes_out
|
||||||
|
self.nodes = traverse_all(self.nodes_in, self.nodes_out)
|
||||||
self.make_weights()
|
self.make_weights()
|
||||||
for node in self.ordered_nodes:
|
for node in self.nodes:
|
||||||
node.unsafe = unsafe
|
node.unsafe = unsafe
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ordered_nodes(self):
|
||||||
|
# deprecated? we don't guarantee an order like we did before.
|
||||||
|
return self.nodes
|
||||||
|
|
||||||
def make_weights(self):
|
def make_weights(self):
|
||||||
self.param_count = sum((node.size for node in self.ordered_nodes))
|
self.param_count = sum((node.size for node in self.nodes))
|
||||||
self.W = np.zeros(self.param_count, dtype=_f)
|
self.W = np.zeros(self.param_count, dtype=_f)
|
||||||
self.dW = np.zeros(self.param_count, dtype=_f)
|
self.dW = np.zeros(self.param_count, dtype=_f)
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
for node in self.ordered_nodes:
|
for node in self.nodes:
|
||||||
if node.size > 0:
|
if node.size > 0:
|
||||||
inner_offset = 0
|
inner_offset = 0
|
||||||
|
|
||||||
|
@ -752,39 +797,26 @@ class Model:
|
||||||
assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
|
assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
|
||||||
offset += node.size
|
offset += node.size
|
||||||
|
|
||||||
def traverse(self, nodes, node):
|
|
||||||
if node == self.x:
|
|
||||||
return [node]
|
|
||||||
for parent in node.parents:
|
|
||||||
if parent not in nodes:
|
|
||||||
new_nodes = self.traverse(nodes, parent)
|
|
||||||
for new_node in new_nodes:
|
|
||||||
if new_node not in nodes:
|
|
||||||
nodes.append(new_node)
|
|
||||||
if nodes:
|
|
||||||
nodes.append(node)
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
def forward(self, X, deterministic=False):
|
def forward(self, X, deterministic=False):
|
||||||
values = dict()
|
values = dict()
|
||||||
input_node = self.ordered_nodes[0]
|
input_node = self.nodes[0]
|
||||||
output_node = self.ordered_nodes[-1]
|
output_node = self.nodes[-1]
|
||||||
values[input_node] = input_node._propagate(np.expand_dims(X, 0), deterministic)
|
values[input_node] = input_node._propagate(np.expand_dims(X, 0), deterministic)
|
||||||
for node in self.ordered_nodes[1:]:
|
for node in self.nodes[1:]:
|
||||||
values[node] = node.propagate(values, deterministic)
|
values[node] = node.propagate(values, deterministic)
|
||||||
return values[output_node]
|
return values[output_node]
|
||||||
|
|
||||||
def backward(self, error):
|
def backward(self, error):
|
||||||
values = dict()
|
values = dict()
|
||||||
output_node = self.ordered_nodes[-1]
|
output_node = self.nodes[-1]
|
||||||
values[output_node] = output_node._backpropagate(np.expand_dims(error, 0))
|
values[output_node] = output_node._backpropagate(np.expand_dims(error, 0))
|
||||||
for node in reversed(self.ordered_nodes[:-1]):
|
for node in reversed(self.nodes[:-1]):
|
||||||
values[node] = node.backpropagate(values)
|
values[node] = node.backpropagate(values)
|
||||||
return self.dW
|
return self.dW
|
||||||
|
|
||||||
def regulate_forward(self):
|
def regulate_forward(self):
|
||||||
loss = _0
|
loss = _0
|
||||||
for node in self.ordered_nodes:
|
for node in self.nodes:
|
||||||
if node.loss is not None:
|
if node.loss is not None:
|
||||||
loss += node.loss
|
loss += node.loss
|
||||||
for k, w in node.weights.items():
|
for k, w in node.weights.items():
|
||||||
|
@ -792,7 +824,7 @@ class Model:
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def regulate(self):
|
def regulate(self):
|
||||||
for node in self.ordered_nodes:
|
for node in self.nodes:
|
||||||
for k, w in node.weights.items():
|
for k, w in node.weights.items():
|
||||||
w.update()
|
w.update()
|
||||||
|
|
||||||
|
@ -812,7 +844,7 @@ class Model:
|
||||||
for k in weights.keys():
|
for k in weights.keys():
|
||||||
used[k] = False
|
used[k] = False
|
||||||
|
|
||||||
nodes = [node for node in self.ordered_nodes if node.size > 0]
|
nodes = [node for node in self.nodes if node.size > 0]
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
full_name = str(node).lower()
|
full_name = str(node).lower()
|
||||||
for s_name, o_name in node.serialized.items():
|
for s_name, o_name in node.serialized.items():
|
||||||
|
@ -833,7 +865,7 @@ class Model:
|
||||||
|
|
||||||
counts = defaultdict(lambda: 0)
|
counts = defaultdict(lambda: 0)
|
||||||
|
|
||||||
nodes = [node for node in self.ordered_nodes if node.size > 0]
|
nodes = [node for node in self.nodes if node.size > 0]
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
full_name = str(node).lower()
|
full_name = str(node).lower()
|
||||||
grp = f.create_group(full_name)
|
grp = f.create_group(full_name)
|
||||||
|
|
Loading…
Reference in a new issue