greatly refactor weight handling

Connor Olding 2017-04-10 09:53:54 +00:00
parent c3e2dd56bf
commit a78fc98215
2 changed files with 128 additions and 108 deletions
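In short: layers no longer carve their parameters out of the raw W/dW arrays themselves in init(W, dW). They now declare named Weights objects up front with _new_weights(...), fix each object's shape in make_shape, and afterwards only read forward values through .f and write gradients through .g. Model.make_weights still owns the single flat parameter and gradient buffers, but hands every node an allocate(size) closure that returns matching views of both.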

View file

@@ -93,9 +93,10 @@ class LayerNorm(Layer):
super().__init__()
self.eps = _f(eps)
self.affine = bool(affine)
self.size = None
if self.affine:
self.gamma = self._new_weights('gamma', init=init_ones)
self.beta = self._new_weights('beta', init=init_zeros)
self.serialized = {
'gamma': 'gamma',
'beta': 'beta',
@@ -104,23 +105,12 @@ class LayerNorm(Layer):
def make_shape(self, parent):
shape = parent.output_shape
self.input_shape = shape
self.output_shape = shape
if len(shape) != 1:
return False
self.features = shape[0]
if self.affine:
self.size = 2 * self.features
return shape
def init(self, W, dW):
super().init(W, dW)
f = self.features
self.gamma, self.dgamma = self.W[0*f:1*f], self.dW[0*f:1*f]
self.beta, self.dbeta = self.W[1*f:2*f], self.dW[1*f:2*f]
self.gamma[:] = 1
self.beta[:] = 0
self.gamma.shape = (shape[0],)
self.beta.shape = (shape[0],)
def forward(self, X):
self.mean = X.mean(0)
@@ -130,16 +120,16 @@ class LayerNorm(Layer):
self.Xnorm = self.center / self.std
if self.affine:
return self.gamma * self.Xnorm + self.beta
return self.gamma.f * self.Xnorm + self.beta.f
return self.Xnorm
def backward(self, dY):
length = dY.shape[0]
if self.affine:
dXnorm = dY * self.gamma
self.dgamma[:] = (dY * self.Xnorm).sum(0)
self.dbeta[:] = dY.sum(0)
dXnorm = dY * self.gamma.f
self.gamma.g[:] = (dY * self.Xnorm).sum(0)
self.beta.g[:] = dY.sum(0)
else:
dXnorm = dY
@@ -163,7 +153,8 @@ class Denses(Layer): # TODO: rename?
self.dim = int(dim)
self.weight_init = init
self.axis = int(axis)
self.size = None
self.coeffs = self._new_weights('coeffs', init=init)
self.biases = self._new_weights('biases', init=init_zeros)
def make_shape(self, parent):
shape = parent.output_shape
@@ -178,64 +169,46 @@ class Denses(Layer): # TODO: rename?
self.output_shape[self.axis] = self.dim
self.output_shape = tuple(self.output_shape)
self.nW = self.dim * np.prod(shape)
self.nb = np.prod(self.output_shape)
self.size = self.nW + self.nb
return shape
def init(self, W, dW):
super().init(W, dW)
ins, outs = np.prod(self.input_shape), np.prod(self.output_shape)
in_rows = self.input_shape[0]
in_cols = self.input_shape[1]
out_rows = self.output_shape[0]
out_cols = self.output_shape[1]
self.coeffs = self.W[:self.nW].reshape(in_rows, in_cols, self.dim)
self.biases = self.W[self.nW:].reshape(1, out_rows, out_cols)
self.dcoeffs = self.dW[:self.nW].reshape(self.coeffs.shape)
self.dbiases = self.dW[self.nW:].reshape(self.biases.shape)
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
self.biases.flat = 0
self.std = np.std(self.W)
self.coeffs.shape = (in_rows, in_cols, self.dim)
self.biases.shape = (1, out_rows, out_cols)
def forward(self, X):
self.X = X
if self.axis == 0:
return np.einsum('ixj,xjk->ikj', X, self.coeffs) + self.biases
return np.einsum('ixj,xjk->ikj', X, self.coeffs.f) + self.biases.f
elif self.axis == 1:
return np.einsum('ijx,jxk->ijk', X, self.coeffs) + self.biases
return np.einsum('ijx,jxk->ijk', X, self.coeffs.f) + self.biases.f
def backward(self, dY):
self.dbiases[:] = dY.sum(0, keepdims=True)
self.biases.g[:] = dY.sum(0, keepdims=True)
if self.axis == 0:
self.dcoeffs[:] = np.einsum('ixj,ikj->xjk', self.X, dY)
return np.einsum('ikj,xjk->ixj', dY, self.coeffs)
self.coeffs.g[:] = np.einsum('ixj,ikj->xjk', self.X, dY)
return np.einsum('ikj,xjk->ixj', dY, self.coeffs.f)
elif self.axis == 1:
self.dcoeffs[:] = np.einsum('ijx,ijk->jxk', self.X, dY)
return np.einsum('ijk,jxk->ijx', dY, self.coeffs)
self.coeffs.g[:] = np.einsum('ijx,ijk->jxk', self.X, dY)
return np.einsum('ijk,jxk->ijx', dY, self.coeffs.f)
class DenseOneLess(Dense):
def init(self, W, dW):
super().init(W, dW)
def init(self, allocator):
super().init(allocator)
ins, outs = self.input_shape[0], self.output_shape[0]
assert ins == outs, (ins, outs)
def forward(self, X):
np.fill_diagonal(self.coeffs, 0)
np.fill_diagonal(self.coeffs.f, 0)
self.X = X
return X.dot(self.coeffs) + self.biases
return X.dot(self.coeffs.f) + self.biases
def backward(self, dY):
self.dcoeffs[:] = self.X.T.dot(dY)
self.dbiases[:] = dY.sum(0, keepdims=True)
np.fill_diagonal(self.dcoeffs, 0)
return dY.dot(self.coeffs.T)
self.coeffs.g[:] = self.X.T.dot(dY)
self.biases.g[:] = dY.sum(0, keepdims=True)
np.fill_diagonal(self.coeffs.g, 0)
return dY.dot(self.coeffs.f.T)
class CosineDense(Dense):
# paper: https://arxiv.org/abs/1702.05870
@@ -250,9 +223,9 @@ class CosineDense(Dense):
self.X = X
self.X_norm = np.sqrt(np.square(X).sum(-1, keepdims=True) \
+ 1 + self.eps)
self.W_norm = np.sqrt(np.square(self.coeffs).sum(0, keepdims=True) \
+ np.square(self.biases) + self.eps)
self.dot = X.dot(self.coeffs) + self.biases
self.W_norm = np.sqrt(np.square(self.coeffs.f).sum(0, keepdims=True) \
+ np.square(self.biases.f) + self.eps)
self.dot = X.dot(self.coeffs.f) + self.biases.f
Y = self.dot / (self.X_norm * self.W_norm)
return Y
@@ -261,11 +234,11 @@ class CosineDense(Dense):
dX_norm = -(dY * self.dot / self.W_norm).sum(-1, keepdims=True) / self.X_norm**2
dW_norm = -(dY * self.dot / self.X_norm).sum( 0, keepdims=True) / self.W_norm**2
self.dcoeffs[:] = self.X.T.dot(ddot) \
+ dW_norm / self.W_norm * self.coeffs
self.dbiases[:] = ddot.sum(0, keepdims=True) \
+ dW_norm / self.W_norm * self.biases
dX = ddot.dot(self.coeffs.T) + dX_norm / self.X_norm * self.X
self.coeffs.g[:] = self.X.T.dot(ddot) \
+ dW_norm / self.W_norm * self.coeffs.f
self.biases.g[:] = ddot.sum(0, keepdims=True) \
+ dW_norm / self.W_norm * self.biases.f
dX = ddot.dot(self.coeffs.f.T) + dX_norm / self.X_norm * self.X
return dX
@@ -817,7 +790,7 @@ def run(program, args=None):
ritual.prepare(model)
if training and config.warmup:
if training and config.warmup and not config.fn_load:
log("warming", "up")
# use plain SGD in warmup to prevent (or possibly cause?) numeric issues

View file

@@ -4,8 +4,8 @@ _f = np.float32
# just for speed, not strictly essential:
from scipy.special import expit as sigmoid
# used for numbering layers like Keras:
from collections import defaultdict
# used for numbering layers like Keras, and keeping initialization consistent:
from collections import defaultdict, OrderedDict
_layer_counters = defaultdict(lambda: 0)
def _check(a):
@@ -28,6 +28,12 @@ class LayerIncompatibility(Exception):
# note: these are currently only implemented for 2D shapes.
def init_zeros(size, ins=None, outs=None):
return np.zeros(size)
def init_ones(size, ins=None, outs=None):
return np.ones(size)
def init_he_normal(size, ins, outs):
s = np.sqrt(2 / ins)
return np.random.normal(0, s, size=size)
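For reference: init_ones joins init_zeros so that LayerNorm's gamma can go through the same generic Weights initialization path as every other weight (init=init_ones in the first file), replacing the hand-written self.gamma[:] = 1 that the deleted init(W, dW) method used to perform.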
@@ -264,19 +270,57 @@ class Nadam(Optimizer):
return -self.alpha * mt_bar / (np.sqrt(vtp) + self.eps)
# Weight container {{{1
class Weights:
# we may or may not contain weights -- or any information, for that matter.
def __init__(self, **kwargs):
self.f = None # forward weights
self.g = None # backward weights (gradients)
self.shape = None
self.init = None
self.allocator = None
self.configure(**kwargs)
def configure(self, **kwargs):
for k, v in kwargs.items():
getattr(self, k) # ensures the key already exists
setattr(self, k, v)
@property
def size(self):
assert self.shape is not None
return np.prod(self.shape)
def allocate(self, *args, **kwargs):
self.configure(**kwargs)
# intentionally not using isinstance
assert type(self.shape) == tuple, self.shape
f, g = self.allocator(self.size)
assert len(f) == self.size, "{} != {}".format(f.shape, self.size)
assert len(g) == self.size, "{} != {}".format(g.shape, self.size)
f[:] = self.init(self.size, *args)
g[:] = self.init(self.size, *args)
self.f = f.reshape(self.shape)
self.g = g.reshape(self.shape)
# Abstract Layers {{{1
class Layer:
def __init__(self):
self.parents = []
self.children = []
self.weights = OrderedDict()
self.input_shape = None
self.output_shape = None
kind = self.__class__.__name__
global _layer_counters
_layer_counters[kind] += 1
self.name = "{}_{}".format(kind, _layer_counters[kind])
self.size = None # total weight count (if any)
self.unsafe = False # disables assertions for better performance
def __str__(self):
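To see the new container in isolation, here is a minimal sketch, not part of the commit, that exercises Weights.allocate by hand. It assumes the Weights class and init_he_normal from this file are in scope; the toy_allocator, W, and dW names below are made up for illustration and stand in for the closure that Model.make_weights provides. The point is that .f and .g come back as reshaped views into two shared flat buffers.

    import numpy as np

    W  = np.zeros(12, dtype=np.float32)   # toy flat forward buffer
    dW = np.zeros(12, dtype=np.float32)   # toy flat gradient buffer

    def toy_allocator(size):
        # Model.make_weights tracks offsets; this toy just hands out the front.
        return W[:size], dW[:size]

    w = Weights(init=init_he_normal)      # what Layer._new_weights constructs
    w.shape = (4, 3)                      # normally fixed in make_shape
    w.allocate(4, 3, allocator=toy_allocator)   # ins=4, outs=3 are fed to init

    assert w.f.shape == (4, 3) and w.g.shape == (4, 3)
    w.g[:] = 1                            # gradients written through .g ...
    assert dW.sum() == w.size             # ... land in the shared dW buffer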
@@ -335,11 +379,20 @@ class Layer:
def validate_output(self, Y):
assert Y.shape[1:] == self.output_shape, (str(self), Y.shape[1:], self.output_shape)
def init(self, W, dW):
assert W.ndim == 1 and W.shape[0] == self.size, W.shape
assert dW.ndim == 1 and dW.shape[0] == self.size, dW.shape
self.W = W
self.dW = dW
def _new_weights(self, name, **kwargs):
w = Weights(**kwargs)
assert name not in self.weights, name
self.weights[name] = w
return w
@property
def size(self):
return sum((w.size for w in self.weights.values()))
def init(self, allocator):
ins, outs = self.input_shape[0], self.output_shape[0]
for k, w in self.weights.items():
w.allocate(ins, outs, allocator=allocator)
def propagate(self, values):
if not self.unsafe:
@@ -407,7 +460,6 @@ class Flatten(Layer):
shape = parent.output_shape
self.input_shape = shape
self.output_shape = (np.prod(shape),)
return shape
def forward(self, X):
self.batch_size = X.shape[0]
@@ -527,42 +579,25 @@ class Dense(Layer):
super().__init__()
self.dim = int(dim)
self.output_shape = (dim,)
self.weight_init = init
self.size = None
self.coeffs = self._new_weights('coeffs', init=init)
self.biases = self._new_weights('biases', init=init_zeros)
def make_shape(self, parent):
shape = parent.output_shape
self.input_shape = shape
if len(shape) != 1:
return False
self.nW = self.dim * shape[0]
self.nb = self.dim
self.size = self.nW + self.nb
return shape
def init(self, W, dW):
super().init(W, dW)
ins, outs = self.input_shape[0], self.output_shape[0]
self.coeffs = self.W[:self.nW].reshape(ins, outs)
self.biases = self.W[self.nW:].reshape(1, outs)
self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
self.dbiases = self.dW[self.nW:].reshape(1, outs)
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
self.biases.flat = 0
self.std = np.std(self.W)
self.coeffs.shape = (shape[0], self.dim)
self.biases.shape = (1, self.dim)
def forward(self, X):
self.X = X
return X.dot(self.coeffs) + self.biases
return X.dot(self.coeffs.f) + self.biases.f
def backward(self, dY):
self.dcoeffs[:] = self.X.T.dot(dY)
self.dbiases[:] = dY.sum(0, keepdims=True)
return dY.dot(self.coeffs.T)
self.coeffs.g[:] = self.X.T.dot(dY)
self.biases.g[:] = dY.sum(0, keepdims=True)
return dY.dot(self.coeffs.f.T)
# Models {{{1
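An illustrative count under the new scheme (not taken from the commit): a Dense(128) whose parent outputs 64 features declares coeffs with shape (64, 128) and biases with shape (1, 128), so the size property reports 64*128 + 1*128 = 8320, which is exactly how many entries of W and dW make_weights will reserve for the node.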
@@ -578,18 +613,30 @@ class Model:
node.unsafe = unsafe
def make_weights(self):
self.param_count = 0
for node in self.ordered_nodes:
if node.size is not None:
self.param_count += node.size
self.param_count = sum((node.size for node in self.ordered_nodes))
self.W = np.zeros(self.param_count, dtype=_f)
self.dW = np.zeros(self.param_count, dtype=_f)
offset = 0
for node in self.ordered_nodes:
if node.size is not None:
if node.size > 0:
end = offset + node.size
node.init(self.W[offset:end], self.dW[offset:end])
inner_offset = 0
def allocate(size):
nonlocal inner_offset
o = offset + inner_offset
ret = self.W[o:o+size], self.dW[o:o+size]
inner_offset += size
assert len(ret[0]) == len(ret[1])
assert size == len(ret[0]), (size, len(ret[0]))
return ret
node.init(allocate)
assert inner_offset <= node.size, "Layer {} allocated more weights than it said it would".format(node)
# i don't care if "less" is grammatically incorrect.
# you're mom is grammatically incorrect.
assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
offset += node.size
def traverse(self, nodes, node):
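For reference, a standalone sketch of the allocation scheme make_weights now uses: one flat W/dW pair walked node by node, with a per-node allocate closure handing out paired views. The function name and node sizes below are illustrative, not library code.

    import numpy as np

    def make_buffers(node_sizes):
        W  = np.zeros(sum(node_sizes), dtype=np.float32)
        dW = np.zeros(sum(node_sizes), dtype=np.float32)
        views = []
        offset = 0
        for node_size in node_sizes:
            inner_offset = 0
            def allocate(size):
                nonlocal inner_offset
                o = offset + inner_offset
                inner_offset += size
                return W[o:o+size], dW[o:o+size]
            # a real node's init() calls allocate() once per Weights object;
            # this toy grabs the node's whole block in one go.
            views.append(allocate(node_size))
            offset += node_size
        return W, dW, views

    W, dW, views = make_buffers([6, 4])   # two fake nodes: 6 and 4 parameters
    views[1][0][:] = 1.0                  # write into the second node's forward view
    print(W)                              # first 6 entries stay 0, last 4 become 1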
@@ -638,14 +685,14 @@ class Model:
for k in weights.keys():
used[k] = False
nodes = [node for node in self.ordered_nodes if node.size is not None]
nodes = [node for node in self.ordered_nodes if node.size > 0]
for node in nodes:
full_name = str(node).lower()
for s_name, o_name in node.serialized.items():
key = full_name + '_' + s_name
data = weights[key]
target = getattr(node, o_name)
target[:] = data
target.f[:] = data
used[key] = True
for k, v in used.items():
@@ -658,7 +705,7 @@ class Model:
counts = defaultdict(lambda: 0)
nodes = [node for node in self.ordered_nodes if node.size is not None]
nodes = [node for node in self.ordered_nodes if node.size > 0]
for node in nodes:
full_name = str(node).lower()
grp = f.create_group(full_name)
@@ -666,7 +713,7 @@ class Model:
key = full_name + '_' + s_name
target = getattr(node, o_name)
data = grp.create_dataset(key, target.shape, dtype=_f)
data[:] = target
data[:] = target.f
counts[key] += 1
if counts[key] > 1:
lament("WARNING: rewrote weight", key)