tweak and fix
This commit is contained in:
parent
c49e498aa0
commit
5cb28eeef7
1 changed files with 15 additions and 16 deletions
|
@ -15,9 +15,8 @@ if use_emnist:
|
|||
starts = 2
|
||||
bs = 200
|
||||
|
||||
sgdr = True
|
||||
learner_class = SGDR
|
||||
restart_decay = 0.5
|
||||
CLR = SineCLR
|
||||
|
||||
n_dense = 0
|
||||
n_denses = 2
|
||||
|
@ -38,24 +37,23 @@ if use_emnist:
|
|||
mnist_classes = 47
|
||||
|
||||
else:
|
||||
lr = 0.0032
|
||||
lr = 0.01
|
||||
epochs = 60
|
||||
starts = 3
|
||||
bs = 200
|
||||
bs = 500
|
||||
|
||||
sgdr = True
|
||||
learner_class = SGDR
|
||||
restart_decay = 0.5
|
||||
CLR = SineCLR
|
||||
|
||||
n_dense = 2
|
||||
n_denses = 1
|
||||
n_denses = 0
|
||||
new_dims = (4, 12)
|
||||
activation = Relu
|
||||
|
||||
reg = L1L2(3.2e-5, 3.2e-4)
|
||||
final_reg = L1L2(3.2e-5, 1e-3)
|
||||
dropout = 0.10
|
||||
actreg_lamb = None # 1e-3
|
||||
dropout = 0.05
|
||||
actreg_lamb = None #1e-4
|
||||
|
||||
load_fn = None
|
||||
save_fn = 'mnist.h5'
|
||||
|
@ -93,7 +91,7 @@ inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)
|
|||
|
||||
def regulate(y):
|
||||
if actreg_lamb:
|
||||
assert type(activation) == Relu, type(activation)
|
||||
assert activation == Relu, activation
|
||||
lamb = actreg_lamb # * np.prod(y.output_shape)
|
||||
reg = SaturateRelu(lamb)
|
||||
act = ActivityRegularizer(reg)
|
||||
|
@ -132,13 +130,14 @@ y = y.feed(Softmax())
|
|||
model = Model(x, y, unsafe=True)
|
||||
|
||||
optim = Adam()
|
||||
if sgdr:
|
||||
learner = SGDR(optim, epochs=epochs//starts, rate=lr,
|
||||
restarts=starts-1, restart_decay=restart_decay,
|
||||
expando=lambda i:0)
|
||||
if learner_class == SGDR:
|
||||
learner = learner_class(optim, epochs=epochs//starts, rate=lr,
|
||||
restarts=starts-1, restart_decay=restart_decay,
|
||||
expando=lambda i:0)
|
||||
else:
|
||||
learner = CLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||
frequency=epochs//starts)
|
||||
assert learner_class in (TriangularCLR, SineCLR, WaveCLR)
|
||||
learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||
frequency=epochs//starts)
|
||||
|
||||
loss = CategoricalCrossentropy()
|
||||
mloss = Accuracy()
|
||||
|
|
Loading…
Reference in a new issue