.
This commit is contained in:
parent
166644023e
commit
8337fdb39c
2 changed files with 77 additions and 9 deletions
70
optim_nn.py
70
optim_nn.py
|
@ -113,6 +113,70 @@ class LayerNorm(Layer):
|
|||
|
||||
return dX
|
||||
|
||||
class Denses(Layer): # TODO: rename?
|
||||
# acts as a separate Dense for each row or column. only for 2D arrays.
|
||||
|
||||
def __init__(self, dim, init=init_he_uniform, axis=-1):
|
||||
super().__init__()
|
||||
self.dim = int(dim)
|
||||
self.weight_init = init
|
||||
self.axis = int(axis)
|
||||
self.size = None
|
||||
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
if len(shape) != 2:
|
||||
return False
|
||||
|
||||
assert -len(shape) <= self.axis < len(shape)
|
||||
self.axis = self.axis % len(shape)
|
||||
|
||||
self.output_shape = list(shape)
|
||||
self.output_shape[self.axis] = self.dim
|
||||
self.output_shape = tuple(self.output_shape)
|
||||
|
||||
self.nW = self.dim * np.prod(shape)
|
||||
self.nb = np.prod(self.output_shape)
|
||||
self.size = self.nW + self.nb
|
||||
|
||||
return shape
|
||||
|
||||
def init(self, W, dW):
|
||||
super().init(W, dW)
|
||||
|
||||
ins, outs = np.prod(self.input_shape), np.prod(self.output_shape)
|
||||
|
||||
in_rows = self.input_shape[0]
|
||||
in_cols = self.input_shape[1]
|
||||
out_rows = self.output_shape[0]
|
||||
out_cols = self.output_shape[1]
|
||||
|
||||
self.coeffs = self.W[:self.nW].reshape(in_rows, in_cols, self.dim)
|
||||
self.biases = self.W[self.nW:].reshape(1, out_rows, out_cols)
|
||||
self.dcoeffs = self.dW[:self.nW].reshape(self.coeffs.shape)
|
||||
self.dbiases = self.dW[self.nW:].reshape(self.biases.shape)
|
||||
|
||||
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
|
||||
self.biases.flat = 0
|
||||
|
||||
self.std = np.std(self.W)
|
||||
|
||||
def F(self, X):
|
||||
self.X = X
|
||||
if self.axis == 0:
|
||||
return np.einsum('ixj,xjk->ikj', X, self.coeffs) + self.biases
|
||||
elif self.axis == 1:
|
||||
return np.einsum('ijx,jxk->ijk', X, self.coeffs) + self.biases
|
||||
|
||||
def dF(self, dY):
|
||||
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||
if self.axis == 0:
|
||||
self.dcoeffs[:] = np.einsum('ixj,ikj->xjk', self.X, dY)
|
||||
return np.einsum('ikj,xjk->ixj', dY, self.coeffs)
|
||||
elif self.axis == 1:
|
||||
self.dcoeffs[:] = np.einsum('ijx,ijk->jxk', self.X, dY)
|
||||
return np.einsum('ijk,jxk->ijx', dY, self.coeffs)
|
||||
|
||||
class DenseOneLess(Dense):
|
||||
def init(self, W, dW):
|
||||
super().init(W, dW)
|
||||
|
@ -122,15 +186,13 @@ class DenseOneLess(Dense):
|
|||
def F(self, X):
|
||||
np.fill_diagonal(self.coeffs, 0)
|
||||
self.X = X
|
||||
Y = X.dot(self.coeffs) + self.biases
|
||||
return Y
|
||||
return X.dot(self.coeffs) + self.biases
|
||||
|
||||
def dF(self, dY):
|
||||
dX = dY.dot(self.coeffs.T)
|
||||
self.dcoeffs[:] = self.X.T.dot(dY)
|
||||
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||
np.fill_diagonal(self.dcoeffs, 0)
|
||||
return dX
|
||||
return dY.dot(self.coeffs.T)
|
||||
|
||||
class CosineDense(Dense):
|
||||
# paper: https://arxiv.org/abs/1702.05870
|
||||
|
|
|
@ -49,7 +49,7 @@ class Loss:
|
|||
class CategoricalCrossentropy(Loss):
|
||||
# lifted from theano
|
||||
|
||||
def __init__(self, eps=1e-8):
|
||||
def __init__(self, eps=1e-6):
|
||||
self.eps = _f(eps)
|
||||
|
||||
def F(self, p, y):
|
||||
|
@ -519,14 +519,20 @@ class Dense(Layer):
|
|||
|
||||
def F(self, X):
|
||||
self.X = X
|
||||
Y = X.dot(self.coeffs) + self.biases
|
||||
return Y
|
||||
return X.dot(self.coeffs) + self.biases
|
||||
|
||||
def dF(self, dY):
|
||||
dX = dY.dot(self.coeffs.T)
|
||||
#Y = np.einsum('ix,xj->ij', X, C)
|
||||
#dX = np.einsum('ix,jx->ij', dY, C)
|
||||
#dC = np.einsum('xi,xj->ij', X, dY)
|
||||
# or rather
|
||||
#Y = np.einsum('ix,xj->ij', X, C)
|
||||
#dX = np.einsum('ij,xj->ix', dY, C)
|
||||
#dC = np.einsum('ix,ij->xj', X, dY)
|
||||
# that makes sense, just move the pairs around
|
||||
self.dcoeffs[:] = self.X.T.dot(dY)
|
||||
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||
return dX
|
||||
return dY.dot(self.coeffs.T)
|
||||
|
||||
# Models {{{1
|
||||
|
||||
|
|
Loading…
Reference in a new issue