# optim/onn/model.py

import sys
from .float import _f, _0
from .nodal import *
from .layer_base import *
from .utility import *


class Model:
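    """A network expressed as a graph of layer nodes.

    The graph is traversed from `nodes_in` to `nodes_out`. All trainable
    parameters live in a single flat array (`self.W`), with gradients in
    a parallel array (`self.dW`); each node owns a slice of both.
    """
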
    def __init__(self, nodes_in, nodes_out,
                 loss=None, mloss=None, unsafe=False):
        self.loss = loss if loss is not None else SquaredHalved()
        # fall back to the resolved training loss, not the raw argument,
        # so mloss is never None when loss was also omitted.
        self.mloss = mloss if mloss is not None else self.loss

        nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in
        nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out
        assert isinstance(nodes_in, list), type(nodes_in)
        assert isinstance(nodes_out, list), type(nodes_out)
        self.nodes_in = nodes_in
        self.nodes_out = nodes_out

        self.nodes = traverse_all(self.nodes_in, self.nodes_out)
        self.make_weights()
        for node in self.nodes:
            node.unsafe = unsafe
        # TODO: handle the same layer being in more than one node.

    @property
    def ordered_nodes(self):
        # deprecated? we don't guarantee an order like we did before.
        return self.nodes

    def make_weights(self):
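        """Allocate one flat parameter array and hand each node a slice of it.

        Every non-shared node receives contiguous views into `self.W`
        (weights) and `self.dW` (gradients) through the `allocate`
        closure, so optimizers can treat all parameters as one vector.
        """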
        self.param_count = sum((node.size for node in self.nodes
                                if not node.shared))
        self.W = np.zeros(self.param_count, dtype=_f)
        self.dW = np.zeros(self.param_count, dtype=_f)

        offset = 0
        for node in self.nodes:
            if node.size > 0 and not node.shared:
                inner_offset = 0

                def allocate(size):
                    nonlocal inner_offset
                    o = offset + inner_offset
                    ret = self.W[o:o+size], self.dW[o:o+size]
                    inner_offset += size
                    assert len(ret[0]) == len(ret[1])
                    assert size == len(ret[0]), (size, len(ret[0]))
                    return ret

                fmt = "Layer {} allocated {} weights than it said it would"
                node.init(allocate)
                assert inner_offset <= node.size, fmt.format(node, "more")
                # i don't care if "less" is grammatically incorrect.
                # you're mom is grammatically incorrect.
                assert inner_offset >= node.size, fmt.format(node, "less")
                offset += node.size

    def evaluate(self, input_, deterministic=True):
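        """Propagate a single input through the network; return its output.

        Only valid for networks with exactly one input node and one
        output node; use evaluate_multi() otherwise.
        """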
        fmt = "ambiguous {0} in multi-{0} network; use {1}() instead"
        assert len(self.nodes_in) == 1, fmt.format("input", "evaluate_multi")
        assert len(self.nodes_out) == 1, fmt.format("output", "evaluate_multi")
        node_in = self.nodes_in[0]
        node_out = self.nodes_out[0]
        outputs = self.evaluate_multi({node_in: input_}, deterministic)
        return outputs[node_out]

    def apply(self, error):  # TODO: better name?
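        """Backpropagate an error signal; return the delta at the input.

        Only valid for networks with exactly one input node and one
        output node; use apply_multi() otherwise.
        """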
fmt = "ambiguous input in multi-{} network; use {}() instead"
assert len(self.nodes_in) == 1, fmt.format("input", "apply_multi")
assert len(self.nodes_out) == 1, fmt.format("output", "apply_multi")
2018-01-21 14:04:25 -08:00
node_in = self.nodes_in[0]
node_out = self.nodes_out[0]
inputs = self.apply_multi({node_out: error})
return inputs[node_in]

    def evaluate_multi(self, inputs, deterministic=True):
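        """Propagate a dict of {input_node: data} through the graph.

        Returns a dict mapping each output node to its activation.
        """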
        fmt = "missing {} for node {}"
        values = dict()
        outputs = dict()
        for node in self.nodes:
            if node in self.nodes_in:
                assert node in inputs, fmt.format("input", node.name)
                X = inputs[node]
                values[node] = node._propagate(np.expand_dims(X, 0),
                                               deterministic)
            else:
                values[node] = node.propagate(values, deterministic)
            if node in self.nodes_out:
                outputs[node] = values[node]
        return outputs

    def apply_multi(self, outputs):
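        """Backpropagate a dict of {output_node: error} through the graph.

        Returns a dict mapping each input node to its error delta.
        """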
        fmt = "missing {} for node {}"
        values = dict()
        inputs = dict()
        for node in reversed(self.nodes):
            if node in self.nodes_out:
                assert node in outputs, fmt.format("output", node.name)
                X = outputs[node]
                values[node] = node._backpropagate(np.expand_dims(X, 0))
            else:
                values[node] = node.backpropagate(values)
            if node in self.nodes_in:
                inputs[node] = values[node]
        return inputs

    def forward(self, inputs, outputs, measure=False, deterministic=False):
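        """Evaluate `inputs`, then score the prediction against `outputs`.

        With measure=True, the measurement loss (mloss) is used instead
        of the training loss. Returns (error, predicted).
        """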
        predicted = self.evaluate(inputs, deterministic=deterministic)
        if measure:
            error = self.mloss.forward(predicted, outputs)
        else:
            error = self.loss.forward(predicted, outputs)
        return error, predicted

    def backward(self, predicted, outputs, measure=False):
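        """Differentiate the loss and backpropagate it through the graph.

        Returns the flat gradient array `self.dW` and the delta at the
        network input.
        """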
        if measure:
            error = self.mloss.backward(predicted, outputs)
        else:
            error = self.loss.backward(predicted, outputs)
        # input_delta is rarely useful; it's just here to match the forward pass.
        input_delta = self.apply(error)
        return self.dW, input_delta

    def clear_grad(self):
        for node in self.nodes:
            node.clear_grad()

    def regulate_forward(self):
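        """Accumulate auxiliary losses: per-node losses plus each weight
        regularizer's forward term."""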
        loss = _0
        for node in self.nodes:
            if node.loss is not None:
                loss += node.loss
            for k, w in node.weights.items():
                loss += w.forward()
        return loss

    def regulate(self):
        for node in self.nodes:
            for k, w in node.weights.items():
                w.update()

    def load_weights(self, fn):
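        """Load weights from an HDF5 file, matching datasets to nodes by name.

        Warns about any dataset in the file that no node claimed.
        """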
        # seemingly compatible with keras' Dense layers.
        import h5py
        open(fn).close()  # just ensure the file exists (python's error is better)
        f = h5py.File(fn, 'r')

        weights = {}

        def visitor(name, obj):
            if isinstance(obj, h5py.Dataset):
                weights[name.split('/')[-1]] = np.array(obj[:], dtype=_f)

        f.visititems(visitor)
        f.close()

        used = {}
        for k in weights.keys():
            used[k] = False

        nodes = [node for node in self.nodes if node.size > 0]
        # TODO: support shared weights.
        for node in nodes:
            full_name = str(node).lower()
            for s_name, o_name in node.serialized.items():
                key = full_name + '_' + s_name
                data = weights[key]
                target = getattr(node, o_name)
                target.f[:] = data
                used[key] = True

        for k, v in used.items():
            if not v:
                lament("WARNING: unused weight", k)

    def save_weights(self, fn, overwrite=False):
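        """Write every non-empty node's serialized weights to an HDF5 file.

        Note that `overwrite` is currently unused; mode 'w' always
        truncates an existing file.
        """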
        import h5py
        from collections import defaultdict

        f = h5py.File(fn, 'w')
        counts = defaultdict(int)

        nodes = [node for node in self.nodes if node.size > 0]
        # TODO: support shared weights.
        for node in nodes:
            full_name = str(node).lower()
            grp = f.create_group(full_name)
            for s_name, o_name in node.serialized.items():
                key = full_name + '_' + s_name
                target = getattr(node, o_name)
                data = grp.create_dataset(key, target.shape, dtype=_f)
                data[:] = target.f
                counts[key] += 1
                if counts[key] > 1:
                    lament("WARNING: rewrote weight", key)

        f.close()

    def print_graph(self, file=sys.stdout):
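        """Print the graph of nodes in graphviz dot format."""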
        print('digraph G {', file=file)
        for node in self.nodes:
            children = [str(n) for n in node.children]
            if children:
                sep = '->'
                print('\t' + str(node) + sep +
                      (';\n\t' + str(node) + sep).join(children) + ';',
                      file=file)
        print('}', file=file)
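
# A minimal usage sketch. The layer names below are hypothetical; the
# real constructors live in the .nodal / .layer_base modules re-exported
# above. Only Model's own methods are taken from this file:
#
#   x = Input(...)                   # hypothetical input node
#   y = x.feed(Dense(100))           # hypothetical hidden layer
#   model = Model(x, y, loss=SquaredHalved())
#   err, predicted = model.forward(some_input, some_target)
#   dW, delta = model.backward(predicted, some_target)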