150 lines
4.4 KiB
Python
150 lines
4.4 KiB
Python
import numpy as np
|
|
|
|
from collections import defaultdict, OrderedDict
|
|
|
|
from .weight import *
|
|
|
|
|
|
# used for numbering layers like Keras:
|
|
_layer_counters = defaultdict(lambda: 0)
|
|
|
|
|
|
class LayerIncompatibility(Exception):
|
|
pass
|
|
|
|
|
|
class Layer:
|
|
def __init__(self):
|
|
self.parents = []
|
|
self.children = []
|
|
self.weights = OrderedDict()
|
|
self.loss = None # for activity regularizers
|
|
self.input_shape = None
|
|
self.output_shape = None
|
|
kind = self.__class__.__name__
|
|
global _layer_counters
|
|
_layer_counters[kind] += 1
|
|
self.name = "{}_{}".format(kind, _layer_counters[kind])
|
|
self.unsafe = False # disables assertions for better performance
|
|
self.shared = False # as in weight sharing
|
|
|
|
def __str__(self):
|
|
return self.name
|
|
|
|
# methods we might want to override:
|
|
|
|
def forward(self, X):
|
|
raise NotImplementedError("unimplemented", self)
|
|
|
|
def forward_deterministic(self, X):
|
|
return self.forward(X)
|
|
|
|
def backward(self, dY):
|
|
raise NotImplementedError("unimplemented", self)
|
|
|
|
def make_shape(self, parent):
|
|
if self.input_shape is None:
|
|
self.input_shape = parent.output_shape
|
|
if self.output_shape is None:
|
|
self.output_shape = self.input_shape
|
|
|
|
def do_feed(self, child):
|
|
self.children.append(child)
|
|
|
|
def be_fed(self, parent):
|
|
self.parents.append(parent)
|
|
|
|
# TODO: better names for these (still)
|
|
def _propagate(self, edges, deterministic):
|
|
if not self.unsafe:
|
|
assert len(edges) == 1, self
|
|
if deterministic:
|
|
return self.forward_deterministic(edges[0])
|
|
else:
|
|
return self.forward(edges[0])
|
|
|
|
def _backpropagate(self, edges):
|
|
if len(edges) == 1:
|
|
return self.backward(edges[0])
|
|
return sum((self.backward(dY) for dY in edges))
|
|
|
|
# general utility methods:
|
|
|
|
def is_compatible(self, parent):
|
|
return np.all(self.input_shape == parent.output_shape)
|
|
|
|
def feed(self, child):
|
|
assert self.output_shape is not None, self
|
|
child.make_shape(self)
|
|
if not child.is_compatible(self):
|
|
fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
|
|
raise LayerIncompatibility(fmt.format(
|
|
self, child, self.output_shape, child.input_shape))
|
|
self.do_feed(child)
|
|
child.be_fed(self)
|
|
return child
|
|
|
|
def validate_input(self, X):
|
|
assert X.shape[1:] == self.input_shape, \
|
|
(str(self), X.shape[1:], self.input_shape)
|
|
|
|
def validate_output(self, Y):
|
|
assert Y.shape[1:] == self.output_shape, \
|
|
(str(self), Y.shape[1:], self.output_shape)
|
|
|
|
def _new_weights(self, name, **kwargs):
|
|
w = Weights(**kwargs)
|
|
assert name not in self.weights, name
|
|
self.weights[name] = w
|
|
return w
|
|
|
|
def share(self, node):
|
|
self.weights = node.weights # TODO: this should be all it takes.
|
|
for k, v in self.weights.items():
|
|
# hack: key isn't necessarily attribute name!
|
|
vs = getattr(node, k)
|
|
setattr(self, k, vs)
|
|
self.shared = True
|
|
|
|
def clear_grad(self):
|
|
for name, w in self.weights.items():
|
|
w.g[:] = 0
|
|
|
|
@property
|
|
def size(self):
|
|
return sum((w.size for w in self.weights.values()))
|
|
|
|
def init(self, allocator):
|
|
ins, outs = self.input_shape[0], self.output_shape[0]
|
|
for k, w in self.weights.items():
|
|
w.allocate(ins, outs, allocator=allocator)
|
|
|
|
def propagate(self, values, deterministic):
|
|
if not self.unsafe:
|
|
assert self.parents, self
|
|
edges = []
|
|
for parent in self.parents:
|
|
if parent in values:
|
|
X = values[parent]
|
|
if not self.unsafe:
|
|
self.validate_input(X)
|
|
edges.append(X)
|
|
Y = self._propagate(edges, deterministic)
|
|
if not self.unsafe:
|
|
self.validate_output(Y)
|
|
return Y
|
|
|
|
def backpropagate(self, values):
|
|
if not self.unsafe:
|
|
assert self.children, self
|
|
edges = []
|
|
for child in self.children:
|
|
if child in values:
|
|
dY = values[child]
|
|
if not self.unsafe:
|
|
self.validate_output(dY)
|
|
edges.append(dY)
|
|
dX = self._backpropagate(edges)
|
|
if not self.unsafe:
|
|
self.validate_input(dX)
|
|
return dX
|