diff --git a/optim_nn.py b/optim_nn.py
index 16feca9..357b0fc 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -6,8 +6,22 @@
 nfa = lambda x: np.array(x, dtype=nf)
 ni = np.int
 nia = lambda x: np.array(x, dtype=ni)
+from scipy.special import expit as sigmoid
+
 from collections import defaultdict
 
+# Initializations
+
+# note: these are currently only implemented for 2D shapes.
+
+def init_he_normal(size, ins, outs):
+    s = np.sqrt(2 / ins)
+    return np.random.normal(0, s, size=size)
+
+def init_he_uniform(size, ins, outs):
+    s = np.sqrt(6 / ins)
+    return np.random.uniform(-s, s, size=size)
+
 # Loss functions
 
 class Loss:
@@ -32,6 +46,21 @@ class SquaredHalved(Loss):
     def df(self, r):
         return r
 
+class SomethingElse(Loss):
+    # generalizes Absolute and SquaredHalved
+    # plot: https://www.desmos.com/calculator/fagjg9vuz7
+    def __init__(self, a=4/3):
+        assert 1 <= a <= 2, "parameter out of range"
+        self.a = nf(a / 2)
+        self.b = nf(2 / a)
+        self.c = nf(2 / a - 1)
+
+    def f(self, r):
+        return self.a * np.abs(r)**self.b
+
+    def df(self, r):
+        return np.sign(r) * np.abs(r)**self.c
+
 # Optimizers
 
 class Optimizer:
@@ -102,8 +131,8 @@ class Adam(Optimizer):
         self.b1_t *= self.b1
         self.b2_t *= self.b2
 
-        self.mt = self.b1 * self.mt + (1 - self.b1) * dW
-        self.vt = self.b2 * self.vt + (1 - self.b2) * dW * dW
+        self.mt[:] = self.b1 * self.mt + (1 - self.b1) * dW
+        self.vt[:] = self.b2 * self.vt + (1 - self.b2) * dW * dW
 
         return -self.alpha * (self.mt / (1 - self.b1_t)) \
                / np.sqrt((self.vt / (1 - self.b2_t)) + self.eps)
@@ -306,7 +335,6 @@ class GeluApprox(Layer):
     # paper: https://arxiv.org/abs/1606.08415
     # plot: https://www.desmos.com/calculator/ydzgtccsld
     def F(self, X):
-        from scipy.special import expit as sigmoid
         self.a = 1.704 * X
         self.sig = sigmoid(self.a)
         return X * self.sig
@@ -331,9 +359,7 @@ class Dense(Layer):
         self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
         self.dbiases = self.dW[self.nW:].reshape(1, outs)
 
-        # he_normal initialization
-        s = np.sqrt(2 / ins)
-        self.coeffs.flat = np.random.normal(0, s, size=self.nW)
+        self.coeffs.flat = init_he_uniform(self.nW, ins, outs)
         self.biases.flat = 0
 
     def make_shape(self, shape):
@@ -432,12 +458,64 @@ class Model:
         for i in range(len(denses)):
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
+            # TODO: write a Dense method instead of assigning directly
            denses[a].coeffs = weights[b_name+'_W']
            denses[a].biases = np.expand_dims(weights[b_name+'_b'], 0)
 
     def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)
 
+def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
+    y = x
+    last_size = x.output_shape[0]
+
+    for d in range(depth):
+        size = width
+
+        if last_size != size:
+            y = y.feed(Dense(size))
+
+        if style == 'batchless':
+            skip = y
+            merger = Sum()
+            skip.feed(merger)
+            z_start = skip.feed(activation())
+            for i in range(multi):
+                z = z_start
+                for i in range(block):
+                    if i > 0:
+                        z = z.feed(activation())
+                    z = z.feed(Dense(size))
+                z.feed(merger)
+            y = merger
+        elif style == 'onelesssum':
+            is_last = d + 1 == depth
+            needs_sum = not is_last or multi > 1
+            skip = y
+            if needs_sum:
+                merger = Sum()
+            if not is_last:
+                skip.feed(merger)
+            z_start = skip.feed(activation())
+            for i in range(multi):
+                z = z_start
+                for i in range(block):
+                    if i > 0:
+                        z = z.feed(activation())
+                    z = z.feed(Dense(size))
+                if needs_sum:
+                    z.feed(merger)
+            if needs_sum:
+                y = merger
+            else:
+                y = z
+        else:
+            raise Exception('unknown resnet style', style)
+
+        last_size = size
+
+    return y
+
 if __name__ == '__main__':
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
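
A quick check of SomethingElse's endpoints; a standalone sketch that restates
f and df inline so it runs without importing the module. At a=1 it recovers
SquaredHalved (r**2 / 2) and at a=2 an absolute-error loss (|r|); the default
a=4/3 interpolates between the two.

import numpy as np

def f(r, a):
    return (a / 2) * np.abs(r)**(2 / a)

def df(r, a):
    return np.sign(r) * np.abs(r)**(2 / a - 1)

r = np.linspace(-2.0, 2.0, 101)
assert np.allclose(f(r, 1), np.square(r) / 2)  # SquaredHalved.f
assert np.allclose(df(r, 1), r)                # SquaredHalved.df
assert np.allclose(f(r, 2), np.abs(r))         # absolute error
assert np.allclose(df(r, 2), np.sign(r))       # its (sub)gradient
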
@@ -457,8 +535,7 @@ if __name__ == '__main__':
         res_multi = 1, # normally 1 for plain resnet
 
         # style of resnet (order of layers, which layers, etc.)
-        # only one is implemented so far
-        parallel_style = 'batchless',
+        parallel_style = 'onelesssum',
 
         activation = 'gelu',
         optim = 'adam',
@@ -475,9 +552,14 @@
         # misc
         batch_size = 64,
         init = 'he_normal',
-        loss = 'mse',
-        restart_optim = False, # restarts also reset internal state of optimizer
-        unsafe = False, # aka gotta go fast mode
+        loss = SomethingElse(4/3),
+        log_fn = 'losses.npz',
+        mloss = 'mse',
+        restart_optim = True, # restarts also reset internal state of optimizer
+        unsafe = True, # aka gotta go fast mode
+        train_compare = None,
+        #valid_compare = 0.0007159,
+        valid_compare = 0.0000946,
     )
 
     config.pprint()
@@ -501,34 +583,13 @@
     x = Input(shape=(input_samples,))
     y = x
 
-    last_size = input_samples
-
     activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
     activation = activations[config.activation]
-
-    for blah in range(config.res_depth):
-        size = config.res_width
-
-        if last_size != size:
-            y = y.feed(Dense(size))
-
-        assert config.parallel_style == 'batchless'
-        skip = y
-        merger = Sum()
-        skip.feed(merger)
-        z_start = skip.feed(activation())
-        for i in range(config.res_multi):
-            z = z_start
-            for i in range(config.res_block):
-                if i > 0:
-                    z = z.feed(activation())
-                z = z.feed(Dense(size))
-            z.feed(merger)
-        y = merger
-
-        last_size = size
-
-    if last_size != output_samples:
+    y = multiresnet(y,
+                    config.res_width, config.res_depth,
+                    config.res_block, config.res_multi,
+                    activation=activation)
+    if y.output_shape[0] != output_samples:
         y = y.feed(Dense(output_samples))
 
     model = Model(x, y, unsafe=config.unsafe)
@@ -559,12 +620,17 @@
     else:
         raise Exception('unknown optimizer', config.optim)
 
-    if config.loss == 'mse':
-        loss = Squared()
-    elif config.loss == 'mshe': # mushy
-        loss = SquaredHalved()
-    else:
-        raise Exception('unknown objective', config.loss)
+    def lookup_loss(maybe_name):
+        if isinstance(maybe_name, Loss):
+            return maybe_name
+        elif maybe_name == 'mse':
+            return Squared()
+        elif maybe_name == 'mshe': # mushy
+            return SquaredHalved()
+        raise Exception('unknown objective', maybe_name)
+
+    loss = lookup_loss(config.loss)
+    mloss = lookup_loss(config.mloss) if config.mloss else loss
 
     LR = config.LR
     LRprod = 0.5**(1/config.LR_halve_every)
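
train_compare and valid_compare feed the "improvement" lines that
measure_error prints further down; a standalone restatement of that formula,
using the valid_compare baseline from the config above and a made-up current
error:

valid_compare = 0.0000946  # baseline loss from a previous run
err = 0.0000900            # hypothetical current validation loss
print("improvement", "{:+7.2f}%".format((valid_compare / err - 1) * 100))
# -> improvement +  5.11%
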
@@ -574,21 +640,34 @@
     # Training
 
-    def measure_loss():
-        predicted = model.forward(inputs / x_scale)
-        residual = predicted - outputs / y_scale
-        err = loss.mean(residual)
-        log("train loss", "{:11.7f}".format(err))
-        log("improvement", "{:+7.2f}%".format((0.0007031 / err - 1) * 100))
+    batch_losses = []
+    train_losses = []
+    valid_losses = []
 
-        predicted = model.forward(valid_inputs / x_scale)
-        residual = predicted - valid_outputs / y_scale
-        err = loss.mean(residual)
-        log("valid loss", "{:11.7f}".format(err))
-        log("improvement", "{:+7.2f}%".format((0.0007159 / err - 1) * 100))
-
+    def measure_error():
+        # log("weight mean", "{:11.7f}".format(np.mean(model.W)))
+        # log("weight var", "{:11.7f}".format(np.var(model.W)))
+
+        def print_error(name, inputs, outputs, comparison=None):
+            predicted = model.forward(inputs)
+            residual = predicted - outputs
+            err = mloss.mean(residual)
+            log(name + " loss", "{:11.7f}".format(err))
+            if comparison:
+                log("improvement", "{:+7.2f}%".format((comparison /
+                    err - 1) * 100))
+            return err
+
+        train_err = print_error("train",
+                                inputs / x_scale, outputs / y_scale,
+                                config.train_compare)
+        valid_err = print_error("valid",
+                                valid_inputs / x_scale, valid_outputs / y_scale,
+                                config.valid_compare)
+        train_losses.append(train_err)
+        valid_losses.append(valid_err)
 
     for i in range(config.restarts + 1):
-        measure_loss()
+        measure_error()
 
         if i > 0:
             log("restarting", i)
@@ -620,10 +699,12 @@
             optim.update(dW, model.W)
 
             # note: we don't actually need this for training, only monitoring.
-            cumsum_loss += loss.mean(residual)
-        log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
+            batch_loss = mloss.mean(residual)
+            cumsum_loss += batch_loss
+            batch_losses.append(batch_loss)
+        #log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
 
-        measure_loss()
+        measure_error()
 
     #if training:
     #    model.save_weights(config.fn, overwrite=True)
@@ -636,3 +717,9 @@
     P = model.forward(X) * y_scale
     log("truth", rgbcompare(a, b))
     log("network", np.squeeze(P))
+
+    if config.log_fn is not None:
+        np.savez_compressed(config.log_fn,
+                            batch_losses=nfa(batch_losses),
+                            train_losses=nfa(train_losses),
+                            valid_losses=nfa(valid_losses))
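
Reading the loss log back, as a minimal sketch: it assumes the default log_fn
of 'losses.npz' and that matplotlib is available. batch_losses holds one entry
per training batch, while train_losses and valid_losses hold one entry per
measure_error call.

import numpy as np
import matplotlib.pyplot as plt

log = np.load('losses.npz')
print('final valid loss: {:11.7f}'.format(log['valid_losses'][-1]))
plt.semilogy(log['batch_losses'], label='batch mloss')
plt.xlabel('batch')
plt.ylabel('loss')
plt.legend()
plt.show()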