From a78fc98215d33f9ddf299542838cb029d4f7d512 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 10 Apr 2017 09:53:54 +0000 Subject: [PATCH] greatly refactor weight handling --- optim_nn.py | 101 +++++++++++++---------------------- optim_nn_core.py | 135 ++++++++++++++++++++++++++++++++--------------- 2 files changed, 128 insertions(+), 108 deletions(-) diff --git a/optim_nn.py b/optim_nn.py index eddd8e5..9e52f8d 100755 --- a/optim_nn.py +++ b/optim_nn.py @@ -93,9 +93,10 @@ class LayerNorm(Layer): super().__init__() self.eps = _f(eps) self.affine = bool(affine) - self.size = None if self.affine: + self.gamma = self._new_weights('gamma', init=init_ones) + self.beta = self._new_weights('beta', init=init_zeros) self.serialized = { 'gamma': 'gamma', 'beta': 'beta', @@ -104,23 +105,12 @@ class LayerNorm(Layer): def make_shape(self, parent): shape = parent.output_shape self.input_shape = shape + self.output_shape = shape if len(shape) != 1: return False - self.features = shape[0] if self.affine: - self.size = 2 * self.features - return shape - - def init(self, W, dW): - super().init(W, dW) - - f = self.features - - self.gamma, self.dgamma = self.W[0*f:1*f], self.dW[0*f:1*f] - self.beta, self.dbeta = self.W[1*f:2*f], self.dW[1*f:2*f] - - self.gamma[:] = 1 - self.beta[:] = 0 + self.gamma.shape = (shape[0],) + self.beta.shape = (shape[0],) def forward(self, X): self.mean = X.mean(0) @@ -130,16 +120,16 @@ class LayerNorm(Layer): self.Xnorm = self.center / self.std if self.affine: - return self.gamma * self.Xnorm + self.beta + return self.gamma.f * self.Xnorm + self.beta.f return self.Xnorm def backward(self, dY): length = dY.shape[0] if self.affine: - dXnorm = dY * self.gamma - self.dgamma[:] = (dY * self.Xnorm).sum(0) - self.dbeta[:] = dY.sum(0) + dXnorm = dY * self.gamma.f + self.gamma.g[:] = (dY * self.Xnorm).sum(0) + self.beta.g[:] = dY.sum(0) else: dXnorm = dY @@ -163,7 +153,8 @@ class Denses(Layer): # TODO: rename? self.dim = int(dim) self.weight_init = init self.axis = int(axis) - self.size = None + self.coeffs = self._new_weights('coeffs', init=init) + self.biases = self._new_weights('biases', init=init_zeros) def make_shape(self, parent): shape = parent.output_shape @@ -178,64 +169,46 @@ class Denses(Layer): # TODO: rename? self.output_shape[self.axis] = self.dim self.output_shape = tuple(self.output_shape) - self.nW = self.dim * np.prod(shape) - self.nb = np.prod(self.output_shape) - self.size = self.nW + self.nb - - return shape - - def init(self, W, dW): - super().init(W, dW) - - ins, outs = np.prod(self.input_shape), np.prod(self.output_shape) - in_rows = self.input_shape[0] in_cols = self.input_shape[1] out_rows = self.output_shape[0] out_cols = self.output_shape[1] - self.coeffs = self.W[:self.nW].reshape(in_rows, in_cols, self.dim) - self.biases = self.W[self.nW:].reshape(1, out_rows, out_cols) - self.dcoeffs = self.dW[:self.nW].reshape(self.coeffs.shape) - self.dbiases = self.dW[self.nW:].reshape(self.biases.shape) - - self.coeffs.flat = self.weight_init(self.nW, ins, outs) - self.biases.flat = 0 - - self.std = np.std(self.W) + self.coeffs.shape = (in_rows, in_cols, self.dim) + self.biases.shape = (1, out_rows, out_cols) def forward(self, X): self.X = X if self.axis == 0: - return np.einsum('ixj,xjk->ikj', X, self.coeffs) + self.biases + return np.einsum('ixj,xjk->ikj', X, self.coeffs.f) + self.biases.f elif self.axis == 1: - return np.einsum('ijx,jxk->ijk', X, self.coeffs) + self.biases + return np.einsum('ijx,jxk->ijk', X, self.coeffs.f) + self.biases.f def backward(self, dY): - self.dbiases[:] = dY.sum(0, keepdims=True) + self.biases.g[:] = dY.sum(0, keepdims=True) if self.axis == 0: - self.dcoeffs[:] = np.einsum('ixj,ikj->xjk', self.X, dY) - return np.einsum('ikj,xjk->ixj', dY, self.coeffs) + self.coeffs.g[:] = np.einsum('ixj,ikj->xjk', self.X, dY) + return np.einsum('ikj,xjk->ixj', dY, self.coeffs.f) elif self.axis == 1: - self.dcoeffs[:] = np.einsum('ijx,ijk->jxk', self.X, dY) - return np.einsum('ijk,jxk->ijx', dY, self.coeffs) + self.coeffs.g[:] = np.einsum('ijx,ijk->jxk', self.X, dY) + return np.einsum('ijk,jxk->ijx', dY, self.coeffs.f) class DenseOneLess(Dense): - def init(self, W, dW): - super().init(W, dW) + def init(self, allocator): + super().init(allocator) ins, outs = self.input_shape[0], self.output_shape[0] assert ins == outs, (ins, outs) def forward(self, X): - np.fill_diagonal(self.coeffs, 0) + np.fill_diagonal(self.coeffs.f, 0) self.X = X - return X.dot(self.coeffs) + self.biases + return X.dot(self.coeffs.f) + self.biases def backward(self, dY): - self.dcoeffs[:] = self.X.T.dot(dY) - self.dbiases[:] = dY.sum(0, keepdims=True) - np.fill_diagonal(self.dcoeffs, 0) - return dY.dot(self.coeffs.T) + self.coeffs.g[:] = self.X.T.dot(dY) + self.biases.g[:] = dY.sum(0, keepdims=True) + np.fill_diagonal(self.coeffs.g, 0) + return dY.dot(self.coeffs.f.T) class CosineDense(Dense): # paper: https://arxiv.org/abs/1702.05870 @@ -250,9 +223,9 @@ class CosineDense(Dense): self.X = X self.X_norm = np.sqrt(np.square(X).sum(-1, keepdims=True) \ + 1 + self.eps) - self.W_norm = np.sqrt(np.square(self.coeffs).sum(0, keepdims=True) \ - + np.square(self.biases) + self.eps) - self.dot = X.dot(self.coeffs) + self.biases + self.W_norm = np.sqrt(np.square(self.coeffs.f).sum(0, keepdims=True) \ + + np.square(self.biases.f) + self.eps) + self.dot = X.dot(self.coeffs.f) + self.biases.f Y = self.dot / (self.X_norm * self.W_norm) return Y @@ -261,11 +234,11 @@ class CosineDense(Dense): dX_norm = -(dY * self.dot / self.W_norm).sum(-1, keepdims=True) / self.X_norm**2 dW_norm = -(dY * self.dot / self.X_norm).sum( 0, keepdims=True) / self.W_norm**2 - self.dcoeffs[:] = self.X.T.dot(ddot) \ - + dW_norm / self.W_norm * self.coeffs - self.dbiases[:] = ddot.sum(0, keepdims=True) \ - + dW_norm / self.W_norm * self.biases - dX = ddot.dot(self.coeffs.T) + dX_norm / self.X_norm * self.X + self.coeffs.g[:] = self.X.T.dot(ddot) \ + + dW_norm / self.W_norm * self.coeffs.f + self.biases.g[:] = ddot.sum(0, keepdims=True) \ + + dW_norm / self.W_norm * self.biases.f + dX = ddot.dot(self.coeffs.f.T) + dX_norm / self.X_norm * self.X return dX @@ -817,7 +790,7 @@ def run(program, args=None): ritual.prepare(model) - if training and config.warmup: + if training and config.warmup and not config.fn_load: log("warming", "up") # use plain SGD in warmup to prevent (or possibly cause?) numeric issues diff --git a/optim_nn_core.py b/optim_nn_core.py index f39ee60..8cf9bb9 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -4,8 +4,8 @@ _f = np.float32 # just for speed, not strictly essential: from scipy.special import expit as sigmoid -# used for numbering layers like Keras: -from collections import defaultdict +# used for numbering layers like Keras, and keeping initialization consistent: +from collections import defaultdict, OrderedDict _layer_counters = defaultdict(lambda: 0) def _check(a): @@ -28,6 +28,12 @@ class LayerIncompatibility(Exception): # note: these are currently only implemented for 2D shapes. +def init_zeros(size, ins=None, outs=None): + return np.zeros(size) + +def init_ones(size, ins=None, outs=None): + return np.ones(size) + def init_he_normal(size, ins, outs): s = np.sqrt(2 / ins) return np.random.normal(0, s, size=size) @@ -264,19 +270,57 @@ class Nadam(Optimizer): return -self.alpha * mt_bar / (np.sqrt(vtp) + self.eps) +# Weight container {{{1 + +class Weights: + # we may or may not contain weights -- or any information, for that matter. + + def __init__(self, **kwargs): + self.f = None # forward weights + self.g = None # backward weights (gradients) + self.shape = None + self.init = None + self.allocator = None + + self.configure(**kwargs) + + def configure(self, **kwargs): + for k, v in kwargs.items(): + getattr(self, k) # ensures the key already exists + setattr(self, k, v) + + @property + def size(self): + assert self.shape is not None + return np.prod(self.shape) + + def allocate(self, *args, **kwargs): + self.configure(**kwargs) + + # intentionally not using isinstance + assert type(self.shape) == tuple, self.shape + + f, g = self.allocator(self.size) + assert len(f) == self.size, "{} != {}".format(f.shape, self.size) + assert len(g) == self.size, "{} != {}".format(g.shape, self.size) + f[:] = self.init(self.size, *args) + g[:] = self.init(self.size, *args) + self.f = f.reshape(self.shape) + self.g = g.reshape(self.shape) + # Abstract Layers {{{1 class Layer: def __init__(self): self.parents = [] self.children = [] + self.weights = OrderedDict() self.input_shape = None self.output_shape = None kind = self.__class__.__name__ global _layer_counters _layer_counters[kind] += 1 self.name = "{}_{}".format(kind, _layer_counters[kind]) - self.size = None # total weight count (if any) self.unsafe = False # disables assertions for better performance def __str__(self): @@ -335,11 +379,20 @@ class Layer: def validate_output(self, Y): assert Y.shape[1:] == self.output_shape, (str(self), Y.shape[1:], self.output_shape) - def init(self, W, dW): - assert W.ndim == 1 and W.shape[0] == self.size, W.shape - assert dW.ndim == 1 and dW.shape[0] == self.size, dW.shape - self.W = W - self.dW = dW + def _new_weights(self, name, **kwargs): + w = Weights(**kwargs) + assert name not in self.weights, name + self.weights[name] = w + return w + + @property + def size(self): + return sum((w.size for w in self.weights.values())) + + def init(self, allocator): + ins, outs = self.input_shape[0], self.output_shape[0] + for k, w in self.weights.items(): + w.allocate(ins, outs, allocator=allocator) def propagate(self, values): if not self.unsafe: @@ -407,7 +460,6 @@ class Flatten(Layer): shape = parent.output_shape self.input_shape = shape self.output_shape = (np.prod(shape),) - return shape def forward(self, X): self.batch_size = X.shape[0] @@ -527,42 +579,25 @@ class Dense(Layer): super().__init__() self.dim = int(dim) self.output_shape = (dim,) - self.weight_init = init - self.size = None + self.coeffs = self._new_weights('coeffs', init=init) + self.biases = self._new_weights('biases', init=init_zeros) def make_shape(self, parent): shape = parent.output_shape self.input_shape = shape if len(shape) != 1: return False - self.nW = self.dim * shape[0] - self.nb = self.dim - self.size = self.nW + self.nb - return shape - - def init(self, W, dW): - super().init(W, dW) - - ins, outs = self.input_shape[0], self.output_shape[0] - - self.coeffs = self.W[:self.nW].reshape(ins, outs) - self.biases = self.W[self.nW:].reshape(1, outs) - self.dcoeffs = self.dW[:self.nW].reshape(ins, outs) - self.dbiases = self.dW[self.nW:].reshape(1, outs) - - self.coeffs.flat = self.weight_init(self.nW, ins, outs) - self.biases.flat = 0 - - self.std = np.std(self.W) + self.coeffs.shape = (shape[0], self.dim) + self.biases.shape = (1, self.dim) def forward(self, X): self.X = X - return X.dot(self.coeffs) + self.biases + return X.dot(self.coeffs.f) + self.biases.f def backward(self, dY): - self.dcoeffs[:] = self.X.T.dot(dY) - self.dbiases[:] = dY.sum(0, keepdims=True) - return dY.dot(self.coeffs.T) + self.coeffs.g[:] = self.X.T.dot(dY) + self.biases.g[:] = dY.sum(0, keepdims=True) + return dY.dot(self.coeffs.f.T) # Models {{{1 @@ -578,18 +613,30 @@ class Model: node.unsafe = unsafe def make_weights(self): - self.param_count = 0 - for node in self.ordered_nodes: - if node.size is not None: - self.param_count += node.size + self.param_count = sum((node.size for node in self.ordered_nodes)) self.W = np.zeros(self.param_count, dtype=_f) self.dW = np.zeros(self.param_count, dtype=_f) offset = 0 for node in self.ordered_nodes: - if node.size is not None: + if node.size > 0: end = offset + node.size - node.init(self.W[offset:end], self.dW[offset:end]) + inner_offset = 0 + + def allocate(size): + nonlocal inner_offset + o = offset + inner_offset + ret = self.W[o:o+size], self.dW[o:o+size] + inner_offset += size + assert len(ret[0]) == len(ret[1]) + assert size == len(ret[0]), (size, len(ret[0])) + return ret + + node.init(allocate) + assert inner_offset <= node.size, "Layer {} allocated more weights than it said it would".format(node) + # i don't care if "less" is grammatically incorrect. + # you're mom is grammatically incorrect. + assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node) offset += node.size def traverse(self, nodes, node): @@ -638,14 +685,14 @@ class Model: for k in weights.keys(): used[k] = False - nodes = [node for node in self.ordered_nodes if node.size is not None] + nodes = [node for node in self.ordered_nodes if node.size > 0] for node in nodes: full_name = str(node).lower() for s_name, o_name in node.serialized.items(): key = full_name + '_' + s_name data = weights[key] target = getattr(node, o_name) - target[:] = data + target.f[:] = data used[key] = True for k, v in used.items(): @@ -658,7 +705,7 @@ class Model: counts = defaultdict(lambda: 0) - nodes = [node for node in self.ordered_nodes if node.size is not None] + nodes = [node for node in self.ordered_nodes if node.size > 0] for node in nodes: full_name = str(node).lower() grp = f.create_group(full_name) @@ -666,7 +713,7 @@ class Model: key = full_name + '_' + s_name target = getattr(node, o_name) data = grp.create_dataset(key, target.shape, dtype=_f) - data[:] = target + data[:] = target.f counts[key] += 1 if counts[key] > 1: lament("WARNING: rewrote weight", key)