diff --git a/optim_nn_mnist.py b/optim_nn_mnist.py
index 7749786..900ef15 100755
--- a/optim_nn_mnist.py
+++ b/optim_nn_mnist.py
@@ -15,9 +15,8 @@ if use_emnist:
     starts = 2
     bs = 200

-    sgdr = True
+    learner_class = SGDR
     restart_decay = 0.5
-    CLR = SineCLR

     n_dense = 0
     n_denses = 2
@@ -38,24 +37,23 @@ if use_emnist:
     mnist_classes = 47
 else:
-    lr = 0.0032
+    lr = 0.01
     epochs = 60
     starts = 3
-    bs = 200
+    bs = 500

-    sgdr = True
+    learner_class = SGDR
     restart_decay = 0.5
-    CLR = SineCLR

     n_dense = 2
-    n_denses = 1
+    n_denses = 0
     new_dims = (4, 12)
     activation = Relu

     reg = L1L2(3.2e-5, 3.2e-4)
     final_reg = L1L2(3.2e-5, 1e-3)
-    dropout = 0.10
-    actreg_lamb = None # 1e-3
+    dropout = 0.05
+    actreg_lamb = None #1e-4


 load_fn = None
 save_fn = 'mnist.h5'
@@ -93,7 +91,7 @@ inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)

 def regulate(y):
     if actreg_lamb:
-        assert type(activation) == Relu, type(activation)
+        assert activation == Relu, activation
         lamb = actreg_lamb # * np.prod(y.output_shape)
         reg = SaturateRelu(lamb)
         act = ActivityRegularizer(reg)
@@ -132,13 +130,14 @@ y = y.feed(Softmax())
 model = Model(x, y, unsafe=True)

 optim = Adam()
-if sgdr:
-    learner = SGDR(optim, epochs=epochs//starts, rate=lr,
-                   restarts=starts-1, restart_decay=restart_decay,
-                   expando=lambda i:0)
+if learner_class == SGDR:
+    learner = learner_class(optim, epochs=epochs//starts, rate=lr,
+                            restarts=starts-1, restart_decay=restart_decay,
+                            expando=lambda i:0)
 else:
-    learner = CLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
-                  frequency=epochs//starts)
+    assert learner_class in (TriangularCLR, SineCLR, WaveCLR)
+    learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
+                            frequency=epochs//starts)

 loss = CategoricalCrossentropy()
 mloss = Accuracy()
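
Note on the pattern: this change replaces the old sgdr boolean plus the
CLR = SineCLR alias with a single learner_class selector, so a schedule is
picked by naming its class once rather than by toggling a flag and rebinding
an alias. Below is a minimal, self-contained sketch of that dispatch; the
stub classes and the make_learner helper are illustrative stand-ins for the
repo's real SGDR/TriangularCLR/SineCLR/WaveCLR learners, not its API (only
the keyword arguments mirror the calls in the diff).

class Learner:
    # Stub base: the real learners wrap an optimizer and drive its
    # learning rate over the course of training.
    def __init__(self, optim, epochs, **kwargs):
        self.optim, self.epochs = optim, epochs

class SGDR(Learner): pass           # warm restarts: rate, restarts, restart_decay, expando
class TriangularCLR(Learner): pass  # cyclical rate between lower_rate and upper_rate
class SineCLR(TriangularCLR): pass
class WaveCLR(TriangularCLR): pass

def make_learner(learner_class, optim, lr, epochs, starts, restart_decay):
    # Mirrors the branch in the diff: the class itself is the configuration,
    # and the branch only decides which keyword arguments apply.
    if learner_class == SGDR:
        return learner_class(optim, epochs=epochs//starts, rate=lr,
                             restarts=starts-1, restart_decay=restart_decay,
                             expando=lambda i: 0)
    assert learner_class in (TriangularCLR, SineCLR, WaveCLR)
    return learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
                         frequency=epochs//starts)

# optim=None stands in for the repo's Adam(); values match the new MNIST config.
learner = make_learner(SGDR, optim=None, lr=0.01, epochs=60, starts=3,
                       restart_decay=0.5)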