greatly refactor weight handling

parent c3e2dd56bf
commit a78fc98215

2 changed files with 128 additions and 108 deletions

optim_nn.py (101 changed lines)
@@ -93,9 +93,10 @@ class LayerNorm(Layer):
        super().__init__()
        self.eps = _f(eps)
        self.affine = bool(affine)
        self.size = None

        if self.affine:
            self.gamma = self._new_weights('gamma', init=init_ones)
            self.beta = self._new_weights('beta', init=init_zeros)
            self.serialized = {
                'gamma': 'gamma',
                'beta': 'beta',
@@ -104,23 +105,12 @@ class LayerNorm(Layer):
    def make_shape(self, parent):
        shape = parent.output_shape
        self.input_shape = shape
        self.output_shape = shape
        if len(shape) != 1:
            return False
        self.features = shape[0]
        if self.affine:
            self.size = 2 * self.features
        return shape

    def init(self, W, dW):
        super().init(W, dW)

        f = self.features

        self.gamma, self.dgamma = self.W[0*f:1*f], self.dW[0*f:1*f]
        self.beta, self.dbeta = self.W[1*f:2*f], self.dW[1*f:2*f]

        self.gamma[:] = 1
        self.beta[:] = 0
            self.gamma.shape = (shape[0],)
            self.beta.shape = (shape[0],)

    def forward(self, X):
        self.mean = X.mean(0)
@@ -130,16 +120,16 @@ class LayerNorm(Layer):

        self.Xnorm = self.center / self.std
        if self.affine:
            return self.gamma * self.Xnorm + self.beta
            return self.gamma.f * self.Xnorm + self.beta.f
        return self.Xnorm

    def backward(self, dY):
        length = dY.shape[0]

        if self.affine:
            dXnorm = dY * self.gamma
            self.dgamma[:] = (dY * self.Xnorm).sum(0)
            self.dbeta[:] = dY.sum(0)
            dXnorm = dY * self.gamma.f
            self.gamma.g[:] = (dY * self.Xnorm).sum(0)
            self.beta.g[:] = dY.sum(0)
        else:
            dXnorm = dY
@@ -163,7 +153,8 @@ class Denses(Layer): # TODO: rename?
        self.dim = int(dim)
        self.weight_init = init
        self.axis = int(axis)
        self.size = None
        self.coeffs = self._new_weights('coeffs', init=init)
        self.biases = self._new_weights('biases', init=init_zeros)

    def make_shape(self, parent):
        shape = parent.output_shape
@@ -178,64 +169,46 @@ class Denses(Layer): # TODO: rename?
        self.output_shape[self.axis] = self.dim
        self.output_shape = tuple(self.output_shape)

        self.nW = self.dim * np.prod(shape)
        self.nb = np.prod(self.output_shape)
        self.size = self.nW + self.nb

        return shape

    def init(self, W, dW):
        super().init(W, dW)

        ins, outs = np.prod(self.input_shape), np.prod(self.output_shape)

        in_rows = self.input_shape[0]
        in_cols = self.input_shape[1]
        out_rows = self.output_shape[0]
        out_cols = self.output_shape[1]

        self.coeffs = self.W[:self.nW].reshape(in_rows, in_cols, self.dim)
        self.biases = self.W[self.nW:].reshape(1, out_rows, out_cols)
        self.dcoeffs = self.dW[:self.nW].reshape(self.coeffs.shape)
        self.dbiases = self.dW[self.nW:].reshape(self.biases.shape)

        self.coeffs.flat = self.weight_init(self.nW, ins, outs)
        self.biases.flat = 0

        self.std = np.std(self.W)
        self.coeffs.shape = (in_rows, in_cols, self.dim)
        self.biases.shape = (1, out_rows, out_cols)

    def forward(self, X):
        self.X = X
        if self.axis == 0:
            return np.einsum('ixj,xjk->ikj', X, self.coeffs) + self.biases
            return np.einsum('ixj,xjk->ikj', X, self.coeffs.f) + self.biases.f
        elif self.axis == 1:
            return np.einsum('ijx,jxk->ijk', X, self.coeffs) + self.biases
            return np.einsum('ijx,jxk->ijk', X, self.coeffs.f) + self.biases.f

    def backward(self, dY):
        self.dbiases[:] = dY.sum(0, keepdims=True)
        self.biases.g[:] = dY.sum(0, keepdims=True)
        if self.axis == 0:
            self.dcoeffs[:] = np.einsum('ixj,ikj->xjk', self.X, dY)
            return np.einsum('ikj,xjk->ixj', dY, self.coeffs)
            self.coeffs.g[:] = np.einsum('ixj,ikj->xjk', self.X, dY)
            return np.einsum('ikj,xjk->ixj', dY, self.coeffs.f)
        elif self.axis == 1:
            self.dcoeffs[:] = np.einsum('ijx,ijk->jxk', self.X, dY)
            return np.einsum('ijk,jxk->ijx', dY, self.coeffs)
            self.coeffs.g[:] = np.einsum('ijx,ijk->jxk', self.X, dY)
            return np.einsum('ijk,jxk->ijx', dY, self.coeffs.f)

class DenseOneLess(Dense):
    def init(self, W, dW):
        super().init(W, dW)
    def init(self, allocator):
        super().init(allocator)
        ins, outs = self.input_shape[0], self.output_shape[0]
        assert ins == outs, (ins, outs)

    def forward(self, X):
        np.fill_diagonal(self.coeffs, 0)
        np.fill_diagonal(self.coeffs.f, 0)
        self.X = X
        return X.dot(self.coeffs) + self.biases
        return X.dot(self.coeffs.f) + self.biases

    def backward(self, dY):
        self.dcoeffs[:] = self.X.T.dot(dY)
        self.dbiases[:] = dY.sum(0, keepdims=True)
        np.fill_diagonal(self.dcoeffs, 0)
        return dY.dot(self.coeffs.T)
        self.coeffs.g[:] = self.X.T.dot(dY)
        self.biases.g[:] = dY.sum(0, keepdims=True)
        np.fill_diagonal(self.coeffs.g, 0)
        return dY.dot(self.coeffs.f.T)

class CosineDense(Dense):
    # paper: https://arxiv.org/abs/1702.05870
@@ -250,9 +223,9 @@ class CosineDense(Dense):
        self.X = X
        self.X_norm = np.sqrt(np.square(X).sum(-1, keepdims=True) \
          + 1 + self.eps)
        self.W_norm = np.sqrt(np.square(self.coeffs).sum(0, keepdims=True) \
          + np.square(self.biases) + self.eps)
        self.dot = X.dot(self.coeffs) + self.biases
        self.W_norm = np.sqrt(np.square(self.coeffs.f).sum(0, keepdims=True) \
          + np.square(self.biases.f) + self.eps)
        self.dot = X.dot(self.coeffs.f) + self.biases.f
        Y = self.dot / (self.X_norm * self.W_norm)
        return Y
@@ -261,11 +234,11 @@ class CosineDense(Dense):
        dX_norm = -(dY * self.dot / self.W_norm).sum(-1, keepdims=True) / self.X_norm**2
        dW_norm = -(dY * self.dot / self.X_norm).sum( 0, keepdims=True) / self.W_norm**2

        self.dcoeffs[:] = self.X.T.dot(ddot) \
          + dW_norm / self.W_norm * self.coeffs
        self.dbiases[:] = ddot.sum(0, keepdims=True) \
          + dW_norm / self.W_norm * self.biases
        dX = ddot.dot(self.coeffs.T) + dX_norm / self.X_norm * self.X
        self.coeffs.g[:] = self.X.T.dot(ddot) \
          + dW_norm / self.W_norm * self.coeffs.f
        self.biases.g[:] = ddot.sum(0, keepdims=True) \
          + dW_norm / self.W_norm * self.biases.f
        dX = ddot.dot(self.coeffs.f.T) + dX_norm / self.X_norm * self.X

        return dX
@@ -817,7 +790,7 @@ def run(program, args=None):

    ritual.prepare(model)

    if training and config.warmup:
    if training and config.warmup and not config.fn_load:
        log("warming", "up")

        # use plain SGD in warmup to prevent (or possibly cause?) numeric issues
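The pattern throughout optim_nn.py is the same: a layer no longer receives raw W/dW buffers and slices them by hand in init(); it declares named Weights objects in __init__, gives them a shape in make_shape, then reads values through .f and writes gradients into .g. A rough before/after sketch of that pattern, using a hypothetical Bias layer (not part of this commit) and the new Weights/Layer API from optim_nn_core.py below:

    # old style: receive flat buffers and slice them by hand
    def init(self, W, dW):
        super().init(W, dW)
        self.biases, self.dbiases = self.W[:], self.dW[:]

    # new style: declare a named weight; the framework allocates it later
    def __init__(self):
        super().__init__()
        self.biases = self._new_weights('biases', init=init_zeros)

    def make_shape(self, parent):
        shape = parent.output_shape
        self.input_shape = shape
        self.output_shape = shape
        self.biases.shape = (shape[0],)
        return shape

    def forward(self, X):
        return X + self.biases.f        # .f: forward values

    def backward(self, dY):
        self.biases.g[:] = dY.sum(0)    # .g: gradient buffer
        return dY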
optim_nn_core.py (135 changed lines)
@@ -4,8 +4,8 @@ _f = np.float32
# just for speed, not strictly essential:
from scipy.special import expit as sigmoid

# used for numbering layers like Keras:
from collections import defaultdict
# used for numbering layers like Keras, and keeping initialization consistent:
from collections import defaultdict, OrderedDict
_layer_counters = defaultdict(lambda: 0)

def _check(a):
@@ -28,6 +28,12 @@ class LayerIncompatibility(Exception):

# note: these are currently only implemented for 2D shapes.

def init_zeros(size, ins=None, outs=None):
    return np.zeros(size)

def init_ones(size, ins=None, outs=None):
    return np.ones(size)

def init_he_normal(size, ins, outs):
    s = np.sqrt(2 / ins)
    return np.random.normal(0, s, size=size)
@@ -264,19 +270,57 @@ class Nadam(Optimizer):

        return -self.alpha * mt_bar / (np.sqrt(vtp) + self.eps)

# Weight container {{{1

class Weights:
    # we may or may not contain weights -- or any information, for that matter.

    def __init__(self, **kwargs):
        self.f = None # forward weights
        self.g = None # backward weights (gradients)
        self.shape = None
        self.init = None
        self.allocator = None

        self.configure(**kwargs)

    def configure(self, **kwargs):
        for k, v in kwargs.items():
            getattr(self, k) # ensures the key already exists
            setattr(self, k, v)

    @property
    def size(self):
        assert self.shape is not None
        return np.prod(self.shape)

    def allocate(self, *args, **kwargs):
        self.configure(**kwargs)

        # intentionally not using isinstance
        assert type(self.shape) == tuple, self.shape

        f, g = self.allocator(self.size)
        assert len(f) == self.size, "{} != {}".format(f.shape, self.size)
        assert len(g) == self.size, "{} != {}".format(g.shape, self.size)
        f[:] = self.init(self.size, *args)
        g[:] = self.init(self.size, *args)
        self.f = f.reshape(self.shape)
        self.g = g.reshape(self.shape)

# Abstract Layers {{{1

class Layer:
    def __init__(self):
        self.parents = []
        self.children = []
        self.weights = OrderedDict()
        self.input_shape = None
        self.output_shape = None
        kind = self.__class__.__name__
        global _layer_counters
        _layer_counters[kind] += 1
        self.name = "{}_{}".format(kind, _layer_counters[kind])
        self.size = None # total weight count (if any)
        self.unsafe = False # disables assertions for better performance

    def __str__(self):
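To make the Weights contract concrete, here is a small sketch (not from the commit) of what allocate() expects and produces. It assumes the Weights class and init_he_normal defined above are in scope; the flat buffers and toy_allocator are made up for illustration:

    import numpy as np

    # toy allocator: hands out matching value/gradient slices from flat buffers
    W = np.zeros(64, dtype=np.float32)
    dW = np.zeros(64, dtype=np.float32)

    def toy_allocator(size):
        return W[:size], dW[:size]

    w = Weights(init=init_he_normal, shape=(4, 8))
    w.allocate(4, 8, allocator=toy_allocator)  # extra args are (ins, outs) for init
    assert w.f.shape == (4, 8)  # values, a view into W
    assert w.g.shape == (4, 8)  # gradients, a view into dW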
@@ -335,11 +379,20 @@ class Layer:
    def validate_output(self, Y):
        assert Y.shape[1:] == self.output_shape, (str(self), Y.shape[1:], self.output_shape)

    def init(self, W, dW):
        assert W.ndim == 1 and W.shape[0] == self.size, W.shape
        assert dW.ndim == 1 and dW.shape[0] == self.size, dW.shape
        self.W = W
        self.dW = dW
    def _new_weights(self, name, **kwargs):
        w = Weights(**kwargs)
        assert name not in self.weights, name
        self.weights[name] = w
        return w

    @property
    def size(self):
        return sum((w.size for w in self.weights.values()))

    def init(self, allocator):
        ins, outs = self.input_shape[0], self.output_shape[0]
        for k, w in self.weights.items():
            w.allocate(ins, outs, allocator=allocator)

    def propagate(self, values):
        if not self.unsafe:
@@ -407,7 +460,6 @@ class Flatten(Layer):
        shape = parent.output_shape
        self.input_shape = shape
        self.output_shape = (np.prod(shape),)
        return shape

    def forward(self, X):
        self.batch_size = X.shape[0]
@@ -527,42 +579,25 @@ class Dense(Layer):
        super().__init__()
        self.dim = int(dim)
        self.output_shape = (dim,)
        self.weight_init = init
        self.size = None
        self.coeffs = self._new_weights('coeffs', init=init)
        self.biases = self._new_weights('biases', init=init_zeros)

    def make_shape(self, parent):
        shape = parent.output_shape
        self.input_shape = shape
        if len(shape) != 1:
            return False
        self.nW = self.dim * shape[0]
        self.nb = self.dim
        self.size = self.nW + self.nb
        return shape

    def init(self, W, dW):
        super().init(W, dW)

        ins, outs = self.input_shape[0], self.output_shape[0]

        self.coeffs = self.W[:self.nW].reshape(ins, outs)
        self.biases = self.W[self.nW:].reshape(1, outs)
        self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
        self.dbiases = self.dW[self.nW:].reshape(1, outs)

        self.coeffs.flat = self.weight_init(self.nW, ins, outs)
        self.biases.flat = 0

        self.std = np.std(self.W)
        self.coeffs.shape = (shape[0], self.dim)
        self.biases.shape = (1, self.dim)

    def forward(self, X):
        self.X = X
        return X.dot(self.coeffs) + self.biases
        return X.dot(self.coeffs.f) + self.biases.f

    def backward(self, dY):
        self.dcoeffs[:] = self.X.T.dot(dY)
        self.dbiases[:] = dY.sum(0, keepdims=True)
        return dY.dot(self.coeffs.T)
        self.coeffs.g[:] = self.X.T.dot(dY)
        self.biases.g[:] = dY.sum(0, keepdims=True)
        return dY.dot(self.coeffs.f.T)

# Models {{{1
@@ -578,18 +613,30 @@ class Model:
            node.unsafe = unsafe

    def make_weights(self):
        self.param_count = 0
        for node in self.ordered_nodes:
            if node.size is not None:
                self.param_count += node.size
        self.param_count = sum((node.size for node in self.ordered_nodes))
        self.W = np.zeros(self.param_count, dtype=_f)
        self.dW = np.zeros(self.param_count, dtype=_f)

        offset = 0
        for node in self.ordered_nodes:
            if node.size is not None:
            if node.size > 0:
                end = offset + node.size
                node.init(self.W[offset:end], self.dW[offset:end])
                inner_offset = 0

                def allocate(size):
                    nonlocal inner_offset
                    o = offset + inner_offset
                    ret = self.W[o:o+size], self.dW[o:o+size]
                    inner_offset += size
                    assert len(ret[0]) == len(ret[1])
                    assert size == len(ret[0]), (size, len(ret[0]))
                    return ret

                node.init(allocate)
                assert inner_offset <= node.size, "Layer {} allocated more weights than it said it would".format(node)
                # i don't care if "less" is grammatically incorrect.
                # you're mom is grammatically incorrect.
                assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
                offset += node.size

    def traverse(self, nodes, node):
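The net effect: the model still owns a single flat W and dW, and each layer claims contiguous slices for its named weights through the allocate closure instead of being handed its whole range at once. A rough end-to-end sketch of that path for one Dense layer (not from the commit; it assumes the Dense/Layer/Weights definitions in this file and that Dense's initializer argument has a default; _Parent, W, dW, used and allocate are made-up names for illustration):

    import numpy as np

    d = Dense(4)                       # __init__ declares 'coeffs' and 'biases' Weights
    class _Parent: output_shape = (3,) # stand-in for an upstream layer
    d.make_shape(_Parent())            # sets coeffs.shape = (3, 4), biases.shape = (1, 4)

    W = np.zeros(d.size, dtype=np.float32)   # d.size == 12 + 4 == 16
    dW = np.zeros(d.size, dtype=np.float32)
    used = 0

    def allocate(size):                # same contract as the closure in make_weights
        global used
        ret = W[used:used+size], dW[used:used+size]
        used += size
        return ret

    d.init(allocate)                   # Layer.init -> Weights.allocate per named weight
    assert d.coeffs.f.shape == (3, 4)  # values are a view into W
    assert d.coeffs.g.shape == (3, 4)  # gradients are a view into dW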
@@ -638,14 +685,14 @@ class Model:
        for k in weights.keys():
            used[k] = False

        nodes = [node for node in self.ordered_nodes if node.size is not None]
        nodes = [node for node in self.ordered_nodes if node.size > 0]
        for node in nodes:
            full_name = str(node).lower()
            for s_name, o_name in node.serialized.items():
                key = full_name + '_' + s_name
                data = weights[key]
                target = getattr(node, o_name)
                target[:] = data
                target.f[:] = data
                used[key] = True

        for k, v in used.items():
@@ -658,7 +705,7 @@ class Model:

        counts = defaultdict(lambda: 0)

        nodes = [node for node in self.ordered_nodes if node.size is not None]
        nodes = [node for node in self.ordered_nodes if node.size > 0]
        for node in nodes:
            full_name = str(node).lower()
            grp = f.create_group(full_name)
@@ -666,7 +713,7 @@ class Model:
                key = full_name + '_' + s_name
                target = getattr(node, o_name)
                data = grp.create_dataset(key, target.shape, dtype=_f)
                data[:] = target
                data[:] = target.f
                counts[key] += 1
                if counts[key] > 1:
                    lament("WARNING: rewrote weight", key)