2018-01-21 14:04:25 -08:00
|
|
|
import sys
|
|
|
|
|
2018-03-17 06:09:15 -07:00
|
|
|
from .float import _f, _0
|
2018-01-21 14:16:36 -08:00
|
|
|
from .nodal import *
|
2018-01-21 14:04:25 -08:00
|
|
|
from .layer_base import *
|
2018-01-21 14:16:36 -08:00
|
|
|
from .utility import *
|
2018-01-21 14:04:25 -08:00
|
|
|
|
2018-01-22 11:40:36 -08:00
|
|
|
|
2018-01-21 14:04:25 -08:00
|
|
|
class Model:
    """A graph of nodes whose parameters live in one flat weight buffer.

    The node set is discovered with ``traverse_all`` between the given input
    and output nodes.  Parameters of every non-shared node are allocated as
    slices of the contiguous arrays ``self.W`` (weights) and ``self.dW``
    (gradients), see ``make_weights``.
    """

    def __init__(self, nodes_in, nodes_out,
                 loss=None, mloss=None, unsafe=False):
        """Build the graph and allocate its weights.

        Args:
            nodes_in: a single ``Layer`` or a list of input nodes.
            nodes_out: a single ``Layer`` or a list of output nodes.
            loss: loss object used by ``forward``/``backward``; defaults to
                ``SquaredHalved()``.
            mloss: loss object used when ``measure=True``; defaults to the
                (resolved) training loss.
            unsafe: value assigned to the ``unsafe`` attribute of every node.
        """
        self.loss = loss if loss is not None else SquaredHalved()
        # Fall back to the *resolved* training loss rather than the raw
        # argument: otherwise omitting both losses leaves mloss as None and
        # forward/backward(measure=True) would crash.
        self.mloss = mloss if mloss is not None else self.loss

        nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in
        nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out
        assert isinstance(nodes_in, list), type(nodes_in)
        assert isinstance(nodes_out, list), type(nodes_out)
        self.nodes_in = nodes_in
        self.nodes_out = nodes_out

        self.nodes = traverse_all(self.nodes_in, self.nodes_out)
        self.make_weights()
        for node in self.nodes:
            node.unsafe = unsafe
        # TODO: handle the same layer being in more than one node.

    @property
    def ordered_nodes(self):
        """Return ``self.nodes``.

        Deprecated?  No particular ordering is guaranteed any more.
        """
        return self.nodes

    def make_weights(self):
        """Allocate ``self.W``/``self.dW`` and hand each node its slices.

        ``self.param_count`` is the total ``size`` of all non-shared nodes.
        Every non-shared node with a positive ``size`` gets ``node.init``
        called with an ``allocate(size)`` callback returning a
        ``(W_slice, dW_slice)`` pair of views into the flat buffers.  The
        node must request exactly ``node.size`` weights in total.

        Raises:
            AssertionError: if a node allocates more or fewer weights than
                its declared ``size``.
        """
        self.param_count = sum(node.size for node in self.nodes
                               if not node.shared)
        self.W = np.zeros(self.param_count, dtype=_f)
        self.dW = np.zeros(self.param_count, dtype=_f)

        fmt = "Layer {} allocated {} weights than it said it would"
        offset = 0
        for node in self.nodes:
            if node.size <= 0 or node.shared:
                continue
            inner_offset = 0

            # `offset` is read when `allocate` is called, so this assumes
            # node.init uses the callback synchronously (the size checks
            # right after node.init rely on that as well).
            def allocate(size):
                nonlocal inner_offset
                o = offset + inner_offset
                ret = self.W[o:o+size], self.dW[o:o+size]
                inner_offset += size
                assert len(ret[0]) == len(ret[1])
                assert size == len(ret[0]), (size, len(ret[0]))
                return ret

            node.init(allocate)
            assert inner_offset <= node.size, fmt.format(node, "more")
            assert inner_offset >= node.size, fmt.format(node, "less")
            offset += node.size

    def evaluate(self, input_, deterministic=True):
        """Run a single-input, single-output network on ``input_``.

        Returns the value produced by the sole output node.

        Raises:
            AssertionError: if the network has several inputs or outputs;
                use ``evaluate_multi`` in that case.
        """
        fmt = "ambiguous input in multi-{} network; use {}() instead"
        assert len(self.nodes_in) == 1, fmt.format("input", "evaluate_multi")
        assert len(self.nodes_out) == 1, fmt.format("output", "evaluate_multi")
        node_in = self.nodes_in[0]
        node_out = self.nodes_out[0]
        outputs = self.evaluate_multi({node_in: input_}, deterministic)
        return outputs[node_out]

    def apply(self, error):  # TODO: better name?
        """Backpropagate ``error`` through a single-input/output network.

        Returns the value produced for the sole input node.

        Raises:
            AssertionError: if the network has several inputs or outputs;
                use ``apply_multi`` in that case.
        """
        fmt = "ambiguous input in multi-{} network; use {}() instead"
        assert len(self.nodes_in) == 1, fmt.format("input", "apply_multi")
        assert len(self.nodes_out) == 1, fmt.format("output", "apply_multi")
        node_in = self.nodes_in[0]
        node_out = self.nodes_out[0]
        inputs = self.apply_multi({node_out: error})
        return inputs[node_in]

    def evaluate_multi(self, inputs, deterministic=True):
        """Forward pass over all nodes in ``self.nodes`` order.

        Args:
            inputs: dict mapping every input node to its input value.  Each
                value gets a leading axis (``np.expand_dims(X, 0)``) before
                it is passed to the node's ``_propagate``.
            deterministic: forwarded to every node's propagate call.

        Returns:
            dict mapping each output node to its computed value.
        """
        fmt = "missing {} for node {}"
        values = {}
        outputs = {}
        for node in self.nodes:
            if node in self.nodes_in:
                assert node in inputs, fmt.format("input", node.name)
                X = inputs[node]
                values[node] = node._propagate(np.expand_dims(X, 0),
                                               deterministic)
            else:
                values[node] = node.propagate(values, deterministic)
            if node in self.nodes_out:
                outputs[node] = values[node]
        return outputs

    def apply_multi(self, outputs):
        """Backward pass over all nodes in reverse ``self.nodes`` order.

        Args:
            outputs: dict mapping every output node to its error value.  Each
                value gets a leading axis before ``_backpropagate``.

        Returns:
            dict mapping each input node to its backpropagated value.
        """
        fmt = "missing {} for node {}"
        values = {}
        inputs = {}
        for node in reversed(self.nodes):
            if node in self.nodes_out:
                assert node in outputs, fmt.format("output", node.name)
                X = outputs[node]
                values[node] = node._backpropagate(np.expand_dims(X, 0))
            else:
                values[node] = node.backpropagate(values)
            if node in self.nodes_in:
                inputs[node] = values[node]
        return inputs

    def forward(self, inputs, outputs, measure=False, deterministic=False):
        """Evaluate the network and score it against ``outputs``.

        Uses ``self.mloss`` when ``measure`` is true, else ``self.loss``.

        Returns:
            ``(error, predicted)``.
        """
        predicted = self.evaluate(inputs, deterministic=deterministic)
        loss = self.mloss if measure else self.loss
        error = loss.forward(predicted, outputs)
        return error, predicted

    def backward(self, predicted, outputs, measure=False):
        """Backpropagate the loss gradient through the network.

        Uses ``self.mloss`` when ``measure`` is true, else ``self.loss``.

        Returns:
            ``(self.dW, input_delta)``; ``input_delta`` is rarely useful and
            exists only to mirror the forward pass.
        """
        loss = self.mloss if measure else self.loss
        error = loss.backward(predicted, outputs)
        input_delta = self.apply(error)
        return self.dW, input_delta

    def clear_grad(self):
        """Call ``clear_grad()`` on every node."""
        for node in self.nodes:
            node.clear_grad()

    def regulate_forward(self):
        """Sum node-level losses and per-weight ``forward()`` penalties.

        Starts from ``_0``; adds ``node.loss`` when it is not None and the
        result of ``w.forward()`` for every weight in ``node.weights``.
        """
        loss = _0
        for node in self.nodes:
            if node.loss is not None:
                loss += node.loss
            for w in node.weights.values():
                loss += w.forward()
        return loss

    def regulate(self):
        """Call ``update()`` on every weight of every node."""
        for node in self.nodes:
            for w in node.weights.values():
                w.update()

    def load_weights(self, fn):
        """Load weights from the HDF5 file ``fn``.

        Every dataset in the file is indexed by the last component of its
        path.  For each node with a positive size and each
        ``s_name -> o_name`` entry of ``node.serialized``, the dataset named
        ``str(node).lower() + '_' + s_name`` is copied into
        ``getattr(node, o_name).f``.  Datasets that are never used trigger a
        warning via ``lament``.  Seemingly compatible with keras' Dense
        layers (unverified).

        Raises:
            OSError (e.g. FileNotFoundError): if ``fn`` cannot be opened.
            KeyError: if a node's expected dataset is missing.
        """
        import h5py

        # Open with plain Python first purely for its clearer "file not
        # found" error; close the handle immediately instead of leaking it.
        with open(fn):
            pass

        weights = {}

        def visitor(name, obj):
            if isinstance(obj, h5py.Dataset):
                weights[name.split('/')[-1]] = np.array(obj[:], dtype=_f)

        # The context manager closes the file even if visiting fails.
        with h5py.File(fn, 'r') as f:
            f.visititems(visitor)

        used = dict.fromkeys(weights, False)

        # TODO: support shared weights.
        for node in self.nodes:
            if node.size <= 0:
                continue
            full_name = str(node).lower()
            for s_name, o_name in node.serialized.items():
                key = full_name + '_' + s_name
                data = weights[key]
                target = getattr(node, o_name)
                target.f[:] = data
                used[key] = True

        for k, v in used.items():
            if not v:
                lament("WARNING: unused weight", k)

    def save_weights(self, fn, overwrite=False):
        """Write weights to the HDF5 file ``fn`` (layout read by load_weights).

        Each node with a positive size gets a group ``str(node).lower()``
        holding one dataset per ``node.serialized`` entry, named
        ``<group>_<s_name>``.

        NOTE(review): ``overwrite`` is currently ignored -- the file is
        opened in mode 'w', which truncates any existing file.
        """
        import h5py
        from collections import defaultdict

        # The context manager guarantees the file is closed on errors too.
        with h5py.File(fn, 'w') as f:
            counts = defaultdict(int)

            # TODO: support shared weights.
            nodes = [node for node in self.nodes if node.size > 0]
            for node in nodes:
                full_name = str(node).lower()
                grp = f.create_group(full_name)
                for s_name, o_name in node.serialized.items():
                    key = full_name + '_' + s_name
                    target = getattr(node, o_name)
                    data = grp.create_dataset(key, target.shape, dtype=_f)
                    data[:] = target.f
                    counts[key] += 1
                    if counts[key] > 1:
                        lament("WARNING: rewrote weight", key)

    def print_graph(self, file=None):
        """Print the node graph in Graphviz DOT format.

        One ``node->child;`` edge line is written per child.  ``file``
        defaults to ``None``, which makes ``print`` use the *current*
        ``sys.stdout`` (so ``contextlib.redirect_stdout`` is honoured).
        """
        print('digraph G {', file=file)
        for node in self.nodes:
            name = str(node)
            for child in node.children:
                print('\t' + name + '->' + str(child) + ';', file=file)
        print('}', file=file)
|