.
This commit is contained in:
parent
028209b699
commit
3a55e9650c
1 changed files with 76 additions and 17 deletions
89
optim_nn.py
89
optim_nn.py
|
@ -347,8 +347,17 @@ class Dense(Layer):
|
|||
super().__init__()
|
||||
self.dim = ni(dim)
|
||||
self.output_shape = (dim,)
|
||||
self.size = None
|
||||
self.weight_init = init
|
||||
self.size = None
|
||||
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
if len(shape) != 1:
|
||||
return False
|
||||
self.nW = self.dim * shape[0]
|
||||
self.nb = self.dim
|
||||
self.size = self.nW + self.nb
|
||||
return shape
|
||||
|
||||
def init(self, W, dW):
|
||||
ins, outs = self.input_shape[0], self.output_shape[0]
|
||||
|
@ -363,15 +372,6 @@ class Dense(Layer):
|
|||
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
|
||||
self.biases.flat = 0
|
||||
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
if len(shape) != 1:
|
||||
return False
|
||||
self.nW = self.dim * shape[0]
|
||||
self.nb = self.dim
|
||||
self.size = self.nW + self.nb
|
||||
return shape
|
||||
|
||||
def F(self, X):
|
||||
self.X = X
|
||||
Y = X.dot(self.coeffs) \
|
||||
|
@ -381,7 +381,53 @@ class Dense(Layer):
|
|||
def dF(self, dY):
|
||||
dX = dY.dot(self.coeffs.T)
|
||||
self.dcoeffs[:] = self.X.T.dot(dY)
|
||||
self.dbiases[:] = np.sum(dY, axis=0, keepdims=True)
|
||||
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||
return dX
|
||||
|
||||
class DenseOneLess(Dense):
|
||||
def init(self, W, dW):
|
||||
super().init(W, dW)
|
||||
ins, outs = self.input_shape[0], self.output_shape[0]
|
||||
assert ins == outs, (ins, outs)
|
||||
|
||||
def F(self, X):
|
||||
np.fill_diagonal(self.coeffs, 0)
|
||||
self.X = X
|
||||
Y = X.dot(self.coeffs) \
|
||||
+ self.biases
|
||||
return Y
|
||||
|
||||
def dF(self, dY):
|
||||
dX = dY.dot(self.coeffs.T)
|
||||
self.dcoeffs[:] = self.X.T.dot(dY)
|
||||
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||
np.fill_diagonal(self.dcoeffs, 0)
|
||||
return dX
|
||||
|
||||
class LayerNorm(Layer): # TODO: inherit Affine instead?
|
||||
def __init__(self, eps=1e-3, axis=-1):
|
||||
super().__init__()
|
||||
self.eps = nf(eps)
|
||||
self.axis = int(axis)
|
||||
|
||||
def F(self, X):
|
||||
self.center = X - np.mean(X, axis=self.axis, keepdims=True)
|
||||
#self.var = np.var(X, axis=self.axis, keepdims=True) + self.eps
|
||||
self.var = np.mean(np.square(self.center), axis=self.axis, keepdims=True) + self.eps
|
||||
self.std = np.sqrt(self.var) + self.eps
|
||||
Y = self.center / self.std
|
||||
return Y
|
||||
|
||||
def dF(self, dY):
|
||||
length = self.input_shape[self.axis]
|
||||
|
||||
dstd = dY * (-self.center / self.var)
|
||||
dvar = dstd * (0.5 / self.std)
|
||||
dcenter2 = dvar * (1 / length)
|
||||
dcenter = dY * (1 / self.std)
|
||||
dcenter += dcenter2 * (2 * self.center)
|
||||
dX = dcenter - dcenter / length
|
||||
|
||||
return dX
|
||||
|
||||
# Model
|
||||
|
@ -527,6 +573,8 @@ class Ritual:
|
|||
self.update(model.dW, model.W)
|
||||
|
||||
batch_loss = self.measure(residual)
|
||||
if np.isnan(batch_loss):
|
||||
raise Exception("nan")
|
||||
cumsum_loss += batch_loss
|
||||
if return_losses:
|
||||
losses.append(batch_loss)
|
||||
|
@ -542,6 +590,9 @@ def multiresnet(x, width, depth, block=2, multi=1,
|
|||
y = x
|
||||
last_size = x.output_shape[0]
|
||||
|
||||
FC = lambda size: Dense(size, init)
|
||||
#FC = lambda size: DenseOneLess(size, init)
|
||||
|
||||
for d in range(depth):
|
||||
size = width
|
||||
|
||||
|
@ -558,7 +609,7 @@ def multiresnet(x, width, depth, block=2, multi=1,
|
|||
for i in range(block):
|
||||
if i > 0:
|
||||
z = z.feed(activation())
|
||||
z = z.feed(Dense(size, init))
|
||||
z = z.feed(FC(size))
|
||||
z.feed(merger)
|
||||
y = merger
|
||||
elif style == 'onelesssum':
|
||||
|
@ -575,7 +626,7 @@ def multiresnet(x, width, depth, block=2, multi=1,
|
|||
for i in range(block):
|
||||
if i > 0:
|
||||
z = z.feed(activation())
|
||||
z = z.feed(Dense(size, init))
|
||||
z = z.feed(FC(size))
|
||||
if needs_sum:
|
||||
z.feed(merger)
|
||||
if needs_sum:
|
||||
|
@ -630,7 +681,7 @@ def run(program, args=[]):
|
|||
# misc
|
||||
batch_size = 64,
|
||||
init = 'he_normal',
|
||||
loss = SomethingElse(4/3),
|
||||
loss = SomethingElse(),
|
||||
mloss = 'mse',
|
||||
restart_optim = True, # restarts also reset internal state of optimizer
|
||||
unsafe = True, # aka gotta go fast mode
|
||||
|
@ -666,14 +717,22 @@ def run(program, args=[]):
|
|||
y = multiresnet(y,
|
||||
config.res_width, config.res_depth,
|
||||
config.res_block, config.res_multi,
|
||||
activation=activation, init=init)
|
||||
activation=activation, init=init,
|
||||
style=config.parallel_style)
|
||||
if y.output_shape[0] != output_samples:
|
||||
y = y.feed(Dense(output_samples, init))
|
||||
|
||||
model = Model(x, y, unsafe=config.unsafe)
|
||||
|
||||
if 0:
|
||||
node_names = ' '.join([str(node) for node in model.ordered_nodes])
|
||||
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
|
||||
else:
|
||||
for node in model.ordered_nodes:
|
||||
children = [str(n) for n in node.children]
|
||||
if len(children) > 0:
|
||||
sep = '->'
|
||||
print(str(node)+sep+('\n'+str(node)+sep).join(children))
|
||||
log('parameters', model.param_count)
|
||||
|
||||
training = config.epochs > 0 and config.restarts >= 0
|
||||
|
|
Loading…
Reference in a new issue