diff --git a/optim_nn_mnist.py b/optim_nn_mnist.py
index fc092ef..78d3f33 100755
--- a/optim_nn_mnist.py
+++ b/optim_nn_mnist.py
@@ -17,33 +17,48 @@ if use_emnist:
     sgdr = True
     restart_decay = 0.5
+    CLR = SineCLR
 
     n_dense = 0
     n_denses = 2
     new_dims = (28, 28)
     activation = GeluApprox
+    reg = None
+    final_reg = None
+    actreg_lamb = None
+
+    load_fn = None
+    save_fn = 'emnist.h5'
 
     log_fn = 'emnist_losses.npz'
+    fn = 'emnist-balanced.npz'
 
     mnist_dim = 28
     mnist_classes = 47
 
 else:
     lr = 0.0032
-    epochs = 125
-    starts = 5
+    epochs = 60
+    starts = 3
     bs = 200
 
-    activation = Relu
-
-    sgdr = False
+    sgdr = True
     restart_decay = 0.5
+    CLR = SineCLR
 
-    n_dense = 1
+    n_dense = 2
     n_denses = 1
     new_dims = (4, 12)
+    activation = Relu
+    reg = L1L2(3.2e-5, 3.2e-4)
+    final_reg = L1L2(3.2e-5, 1e-3)
+    actreg_lamb = None # 1e-3
+
+    load_fn = None
+    save_fn = 'mnist.h5'
 
     log_fn = 'mnist_losses.npz'
+    fn = 'mnist.npz'
 
     mnist_dim = 28
     mnist_classes = 10
 
@@ -74,23 +89,39 @@ def get_mnist(fn='mnist.npz'):
 inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)
 
+def actreg(y):
+    if not actreg_lamb:
+        return y
+    lamb = actreg_lamb # * np.prod(y.output_shape)
+    reg = SaturateRelu(lamb)
+    act = ActivityRegularizer(reg)
+    reg.lamb_orig = reg.lamb # HACK
+    return y.feed(act)
+
 x = Input(shape=inputs.shape[1:])
 y = x
 y = y.feed(Reshape(new_shape=(mnist_dim, mnist_dim)))
 for i in range(n_denses):
     if i > 0:
+        y = actreg(y)
         y = y.feed(activation())
-    y = y.feed(Denses(new_dims[0], axis=0, init=init_he_normal))
-    y = y.feed(Denses(new_dims[1], axis=1, init=init_he_normal))
+    y = y.feed(Denses(new_dims[0], axis=0, init=init_he_normal,
+                      reg_w=reg, reg_b=reg))
+    y = y.feed(Denses(new_dims[1], axis=1, init=init_he_normal,
+                      reg_w=reg, reg_b=reg))
 y = y.feed(Flatten())
 
 for i in range(n_dense):
     if i > 0:
+        y = actreg(y)
         y = y.feed(activation())
-    y = y.feed(Dense(y.output_shape[0], init=init_he_normal))
+    y = y.feed(Dense(y.output_shape[0], init=init_he_normal,
+                     reg_w=reg, reg_b=reg))
 
+y = actreg(y)
 y = y.feed(activation())
-y = y.feed(Dense(mnist_classes, init=init_glorot_uniform))
+y = y.feed(Dense(mnist_classes, init=init_glorot_uniform,
+                 reg_w=final_reg, reg_b=final_reg))
 y = y.feed(Softmax())
 
 model = Model(x, y, unsafe=True)
 
@@ -101,10 +132,8 @@ if sgdr:
                     restarts=starts-1, restart_decay=restart_decay,
                     expando=lambda i:0)
 else:
-# learner = TriangularCLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
-#                         frequency=epochs//starts)
-    learner = SineCLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
-                      frequency=epochs//starts)
+    learner = CLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
+                  frequency=epochs//starts)
 
 loss = CategoricalCrossentropy()
 mloss = Accuracy()
@@ -113,6 +142,11 @@ ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
 #ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss,
 #                     input_noise=1e-1, output_noise=3.2e-2, gradient_noise=1e-1)
 
+for node in model.ordered_nodes:
+    children = [str(n) for n in node.children]
+    if children:
+        sep = '->'
+        print(str(node) + sep + ('\n' + str(node) + sep).join(children))
 log('parameters', model.param_count)
 
 ritual.prepare(model)
@@ -151,6 +185,12 @@ def measure_error(quiet=False):
 measure_error()
 
 while learner.next():
+    act_t = (learner.epoch - 1) / (learner.epochs - 1)
+    if actreg_lamb:
+        for node in model.ordered_nodes:
+            if isinstance(node, ActivityRegularizer):
+                node.reg.lamb = act_t * node.reg.lamb_orig # HACK
+
     indices = np.arange(inputs.shape[0])
     np.random.shuffle(indices)
     shuffled_inputs = inputs[indices]
@@ -174,6 +214,10 @@ while learner.next():
 if not measure_every_epoch:
     measure_error()
 
+if save_fn is not None:
+    log('saving weights', save_fn)
+    model.save_weights(save_fn, overwrite=True)
+
 if log_fn:
     log('saving losses', log_fn)
     np.savez_compressed(log_fn,
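
A note on the per-epoch block added to the training loop: it anneals each ActivityRegularizer's strength linearly, from zero on the first epoch up to its configured value on the last. A minimal standalone sketch of just that schedule, reusing the formula from the loop (epoch is 1-based; the names mirror the diff, and nothing here depends on the surrounding library):

    def annealed_lamb(epoch, epochs, lamb_orig):
        # fraction of training completed: 0.0 on epoch 1, 1.0 on the final epoch
        t = (epoch - 1) / (epochs - 1)
        return t * lamb_orig

    # e.g. with epochs=60 and lamb_orig=1e-3:
    #   annealed_lamb(1, 60, 1e-3)  -> 0.0
    #   annealed_lamb(30, 60, 1e-3) -> ~4.9e-4
    #   annealed_lamb(60, 60, 1e-3) -> 1e-3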
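The new reg/final_reg settings attach L1L2 penalties to the weights and biases of every Dense/Denses layer, with a stronger L2 term on the output layer. As a rough sketch of what an L1L2(l1, l2) term is assumed to add to the loss here (the conventional elastic-net definition; the library's exact scaling may differ):

    import numpy as np

    def l1l2_penalty(W, l1=3.2e-5, l2=3.2e-4):
        # elastic-net style penalty: l1 * sum(|w|) + l2 * sum(w^2)
        return l1 * np.abs(W).sum() + l2 * np.square(W).sum()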