optim/onn/layer_base.py

import numpy as np
from collections import defaultdict, OrderedDict

from .weight import *

# used for numbering layers like Keras:
_layer_counters = defaultdict(lambda: 0)


class LayerIncompatibility(Exception):
    pass


class Layer:
    def __init__(self):
        self.parents = []
        self.children = []
        self.weights = OrderedDict()
        self.loss = None  # for activity regularizers
        self.input_shape = None
        self.output_shape = None

        kind = self.__class__.__name__
        global _layer_counters
        _layer_counters[kind] += 1
        self.name = "{}_{}".format(kind, _layer_counters[kind])

        self.unsafe = False  # disables assertions for better performance
        self.shared = False  # as in weight sharing

    def __str__(self):
        return self.name

    # methods we might want to override:

    def forward(self, X):
        raise NotImplementedError("unimplemented", self)
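
    # the deterministic forward pass defaults to the regular one;
    # stochastic layers (e.g. dropout) can override it for evaluation.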

    def forward_deterministic(self, X):
        return self.forward(X)

    def backward(self, dY):
        raise NotImplementedError("unimplemented", self)
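
    # shape inference: unless set explicitly, adopt the parent's output
    # shape and pass it through unchanged.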

    def make_shape(self, parent):
        if self.input_shape is None:
            self.input_shape = parent.output_shape
        if self.output_shape is None:
            self.output_shape = self.input_shape
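
    # graph bookkeeping; feed() calls these on both ends of a new edge.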

    def do_feed(self, child):
        self.children.append(child)

    def be_fed(self, parent):
        self.parents.append(parent)

    # TODO: better names for these (still)

    def _propagate(self, edges, deterministic):
        if not self.unsafe:
            assert len(edges) == 1, self
        if deterministic:
            return self.forward_deterministic(edges[0])
        else:
            return self.forward(edges[0])
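
    # gradients arriving from multiple children are summed.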

    def _backpropagate(self, edges):
        if len(edges) == 1:
            return self.backward(edges[0])
        return sum((self.backward(dY) for dY in edges))

    # general utility methods:

    def is_compatible(self, parent):
        return np.all(self.input_shape == parent.output_shape)
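
    # connect self -> child: infer the child's shapes, check compatibility,
    # and record the edge on both ends. returns the child so calls chain.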

    def feed(self, child):
        assert self.output_shape is not None, self
        child.make_shape(self)
        if not child.is_compatible(self):
            fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
            raise LayerIncompatibility(fmt.format(
                self, child, self.output_shape, child.input_shape))
        self.do_feed(child)
        child.be_fed(self)
        return child
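
    # shape checks on batched arrays; axis 0 is the batch dimension.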

    def validate_input(self, X):
        assert X.shape[1:] == self.input_shape, \
            (str(self), X.shape[1:], self.input_shape)

    def validate_output(self, Y):
        assert Y.shape[1:] == self.output_shape, \
            (str(self), Y.shape[1:], self.output_shape)
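
    # register a uniquely-named Weights object on this layer.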

    def _new_weights(self, name, **kwargs):
        w = Weights(**kwargs)
        assert name not in self.weights, name
        self.weights[name] = w
        return w

    def share(self, node):
        self.weights = node.weights  # TODO: this should be all it takes.
        for k, v in self.weights.items():
            vs = getattr(node, k)  # hack: key isn't necessarily attribute name!
            setattr(self, k, vs)
        self.shared = True
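
    # zero the accumulated gradients in-place.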

    def clear_grad(self):
        for name, w in self.weights.items():
            w.g[:] = 0
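
    # total number of weight elements owned by this layer.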

    @property
    def size(self):
        return sum((w.size for w in self.weights.values()))
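
    # allocate storage for every Weights object, passing along the
    # leading input/output dimensions.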

    def init(self, allocator):
        ins, outs = self.input_shape[0], self.output_shape[0]
        for k, w in self.weights.items():
            w.allocate(ins, outs, allocator=allocator)
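
    # gather this layer's inputs from `values` (a dict keyed by parent layer),
    # run the forward pass, and validate shapes unless `unsafe` is set.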

    def propagate(self, values, deterministic):
        if not self.unsafe:
            assert self.parents, self
        edges = []
        for parent in self.parents:
            if parent in values:
                X = values[parent]
                if not self.unsafe:
                    self.validate_input(X)
                edges.append(X)
        Y = self._propagate(edges, deterministic)
        if not self.unsafe:
            self.validate_output(Y)
        return Y
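
    # mirror of propagate: collect upstream gradients from `values`
    # (a dict keyed by child layer) and run the backward pass.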

    def backpropagate(self, values):
        if not self.unsafe:
            assert self.children, self
        edges = []
        for child in self.children:
            if child in values:
                dY = values[child]
                if not self.unsafe:
                    self.validate_output(dY)
                edges.append(dY)
        dX = self._backpropagate(edges)
        if not self.unsafe:
            self.validate_input(dX)
        return dX
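

# A minimal usage sketch, not part of the original module: a hypothetical
# pass-through layer demonstrating the feed()/propagate()/backpropagate()
# contract. Shapes and layer names here are made up. Run as a module
# (e.g. `python -m optim.onn.layer_base`) so the relative import resolves.
if __name__ == "__main__":
    class Identity(Layer):
        def forward(self, X):
            return X

        def backward(self, dY):
            return dY

    a = Identity()
    a.input_shape = (4,)    # a source layer needs its shapes set explicitly
    a.output_shape = (4,)
    b = a.feed(Identity())  # connects a -> b and infers b's shapes

    X = np.zeros((2, 4))    # batch of 2 samples
    Y = b.propagate({a: X}, deterministic=True)
    dX = a.backpropagate({b: np.ones_like(Y)})  # gradient flowing back into a
    print(Y.shape, dX.shape)  # (2, 4) (2, 4)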