diff --git a/optim_nn.py b/optim_nn.py
index 2d3c478..bcb23d8 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -347,8 +347,17 @@ class Dense(Layer):
         super().__init__()
         self.dim = ni(dim)
         self.output_shape = (dim,)
-        self.size = None
         self.weight_init = init
+        self.size = None
+
+    def make_shape(self, shape):
+        super().make_shape(shape)
+        if len(shape) != 1:
+            return False
+        self.nW = self.dim * shape[0]
+        self.nb = self.dim
+        self.size = self.nW + self.nb
+        return shape
 
     def init(self, W, dW):
         ins, outs = self.input_shape[0], self.output_shape[0]
@@ -363,15 +372,6 @@
         self.coeffs.flat = self.weight_init(self.nW, ins, outs)
         self.biases.flat = 0
 
-    def make_shape(self, shape):
-        super().make_shape(shape)
-        if len(shape) != 1:
-            return False
-        self.nW = self.dim * shape[0]
-        self.nb = self.dim
-        self.size = self.nW + self.nb
-        return shape
-
     def F(self, X):
         self.X = X
         Y = X.dot(self.coeffs) \
@@ -381,7 +381,53 @@
     def dF(self, dY):
         dX = dY.dot(self.coeffs.T)
         self.dcoeffs[:] = self.X.T.dot(dY)
-        self.dbiases[:] = np.sum(dY, axis=0, keepdims=True)
+        self.dbiases[:] = dY.sum(0, keepdims=True)
         return dX
+
+class DenseOneLess(Dense):
+    def init(self, W, dW):
+        super().init(W, dW)
+        ins, outs = self.input_shape[0], self.output_shape[0]
+        assert ins == outs, (ins, outs)
+
+    def F(self, X):
+        np.fill_diagonal(self.coeffs, 0)
+        self.X = X
+        Y = X.dot(self.coeffs) \
+          + self.biases
+        return Y
+
+    def dF(self, dY):
+        dX = dY.dot(self.coeffs.T)
+        self.dcoeffs[:] = self.X.T.dot(dY)
+        self.dbiases[:] = dY.sum(0, keepdims=True)
+        np.fill_diagonal(self.dcoeffs, 0)
+        return dX
+
+
+class LayerNorm(Layer): # TODO: inherit Affine instead?
+    def __init__(self, eps=1e-3, axis=-1):
+        super().__init__()
+        self.eps = nf(eps)
+        self.axis = int(axis)
+
+    def F(self, X):
+        self.center = X - np.mean(X, axis=self.axis, keepdims=True)
+        #self.var = np.var(X, axis=self.axis, keepdims=True) + self.eps
+        self.var = np.mean(np.square(self.center), axis=self.axis, keepdims=True) + self.eps
+        self.std = np.sqrt(self.var) + self.eps
+        Y = self.center / self.std
+        return Y
+
+    def dF(self, dY):
+        length = self.input_shape[self.axis]
+
+        dstd = np.sum(dY * -self.center / self.var, axis=self.axis, keepdims=True)
+        dvar = dstd * (0.5 / self.std)
+        dcenter2 = dvar * (1 / length)
+        dcenter = dY * (1 / self.std)
+        dcenter += dcenter2 * (2 * self.center)
+        dX = dcenter - np.mean(dcenter, axis=self.axis, keepdims=True)
+        return dX
 
 # Model
@@ -527,6 +573,8 @@ class Ritual:
             self.update(model.dW, model.W)
 
             batch_loss = self.measure(residual)
+            if np.isnan(batch_loss):
+                raise Exception("nan")
             cumsum_loss += batch_loss
             if return_losses:
                 losses.append(batch_loss)
@@ -542,6 +590,9 @@ def multiresnet(x, width, depth, block=2, multi=1,
     y = x
     last_size = x.output_shape[0]
 
+    FC = lambda size: Dense(size, init)
+    #FC = lambda size: DenseOneLess(size, init)
+
     for d in range(depth):
         size = width
 
@@ -558,7 +609,7 @@
             for i in range(block):
                 if i > 0:
                     z = z.feed(activation())
-                z = z.feed(Dense(size, init))
+                z = z.feed(FC(size))
             z.feed(merger)
             y = merger
         elif style == 'onelesssum':
@@ -575,7 +626,7 @@
             for i in range(block):
                 if i > 0:
                     z = z.feed(activation())
-                z = z.feed(Dense(size, init))
+                z = z.feed(FC(size))
                 if needs_sum:
                     z.feed(merger)
         if needs_sum:
@@ -630,7 +681,7 @@ def run(program, args=[]):
         # misc
         batch_size = 64,
         init = 'he_normal',
-        loss = SomethingElse(4/3),
+        loss = SomethingElse(),
         mloss = 'mse',
         restart_optim = True, # restarts also reset internal state of optimizer
         unsafe = True, # aka gotta go fast mode
@@ -666,14 +717,22 @@ def run(program, args=[]):
     y = multiresnet(y,
                     config.res_width, config.res_depth,
                     config.res_block, config.res_multi,
-                    activation=activation, init=init)
+                    activation=activation, init=init,
+                    style=config.parallel_style)
     if y.output_shape[0] != output_samples:
         y = y.feed(Dense(output_samples, init))
 
     model = Model(x, y, unsafe=config.unsafe)
 
-    node_names = ' '.join([str(node) for node in model.ordered_nodes])
-    log('{} nodes'.format(len(model.ordered_nodes)), node_names)
+    if 0:
+        node_names = ' '.join([str(node) for node in model.ordered_nodes])
+        log('{} nodes'.format(len(model.ordered_nodes)), node_names)
+    else:
+        for node in model.ordered_nodes:
+            children = [str(n) for n in node.children]
+            if len(children) > 0:
+                sep = '->'
+                print(str(node)+sep+('\n'+str(node)+sep).join(children))
    log('parameters', model.param_count)
 
     training = config.epochs > 0 and config.restarts >= 0
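
Notes (not part of the patch):

LayerNorm.dF backpropagates through Y = center / std one step at a time:
dstd gathers the gradient of the per-sample scalar std (hence the sum over
the normalized axis), and the final subtraction of the mean accounts for the
centering, since d(center_j)/d(x_i) = delta_ij - 1/length. A central-difference
gradient check is an easy way to verify it. The sketch below is standalone
plain numpy, not code from optim_nn.py; ln_forward/ln_backward are free-function
rewrites of F/dF, and eps is taken tiny here so the double-eps fudge in F does
not skew the comparison:

    import numpy as np

    def ln_forward(X, eps=1e-8, axis=-1):
        center = X - X.mean(axis=axis, keepdims=True)
        var = np.square(center).mean(axis=axis, keepdims=True) + eps
        std = np.sqrt(var) + eps
        return center / std

    def ln_backward(X, dY, eps=1e-8, axis=-1):
        # mirrors LayerNorm.dF above, recomputing the cached values
        length = X.shape[axis]
        center = X - X.mean(axis=axis, keepdims=True)
        var = np.square(center).mean(axis=axis, keepdims=True) + eps
        std = np.sqrt(var) + eps
        dstd = np.sum(dY * -center / var, axis=axis, keepdims=True)
        dvar = dstd * (0.5 / std)
        dcenter = dY * (1 / std) + (dvar / length) * (2 * center)
        return dcenter - dcenter.mean(axis=axis, keepdims=True)

    rng = np.random.RandomState(420)  # arbitrary seed
    X, dY = rng.randn(2, 5), rng.randn(2, 5)
    num = np.zeros_like(X)
    h = 1e-5
    for i in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[i] += h
        Xm[i] -= h
        # numerical d(sum(Y * dY))/d(X[i]) by central differences
        num[i] = np.sum(dY * (ln_forward(Xp) - ln_forward(Xm))) / (2 * h)
    assert np.allclose(num, ln_backward(X, dY), atol=1e-5)
    print("gradient check passed")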
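DenseOneLess enforces a hard constraint that output unit i never depends on
input i: the diagonal of coeffs is zeroed before every forward pass, and the
diagonal of dcoeffs is zeroed after the backward pass so no update can push
weight back onto it. A minimal standalone illustration of the same trick
(plain numpy; W, X, dY, lr are made-up names, not identifiers from the file):

    import numpy as np

    rng = np.random.RandomState(420)
    n = 4
    W = rng.randn(n, n)                 # square, as DenseOneLess asserts
    X, dY = rng.randn(8, n), rng.randn(8, n)

    np.fill_diagonal(W, 0)              # forward: unit i is blind to input i
    Y = X.dot(W)
    dW = X.T.dot(dY)                    # backward: gradient w.r.t. W
    np.fill_diagonal(dW, 0)             # so the constraint survives updates

    lr = 0.1
    W -= lr * dW                        # e.g. one SGD step
    assert not np.any(np.diag(W))       # diagonal is still exactly zero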