update mnist example

Connor Olding 2017-07-23 04:23:57 +00:00
parent 5183cd38f8
commit 6933e21e0e


@@ -5,9 +5,9 @@ from onn_core import _f
 from dotmap import DotMap
 lower_priority()
-#np.random.seed(42069)
+np.random.seed(42069)
-use_emnist = True
+use_emnist = False
 measure_every_epoch = True
@@ -16,6 +16,7 @@ if use_emnist:
     epochs = 48
     starts = 2
     bs = 400
+    lr *= np.sqrt(bs)
     learner_class = SGDR
     restart_decay = 0.5
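
This hunk moves the square-root batch-size scaling of the learning rate into the EMNIST branch (the old module-level `lr *= np.sqrt(bs)` is dropped in a later hunk). A rough sketch of the arithmetic, assuming a hypothetical base rate since the real value is set outside this hunk:

    import numpy as np

    base_lr = 0.01                # hypothetical base rate; the actual value is set elsewhere in the script
    bs = 400                      # EMNIST branch batch size
    lr = base_lr * np.sqrt(bs)    # sqrt(400) == 20, so lr == 0.2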
@@ -24,6 +25,10 @@ if use_emnist:
     n_denses = 0
     new_dims = (28, 28)
     activation = GeluApprox
+    output_activation = Softmax
+    normalize = True
+    optim = Adam()
     reg = None # L1L2(2.0e-5, 1.0e-4)
     final_reg = None # L1L2(2.0e-5, 1.0e-4)
@@ -43,14 +48,19 @@ else:
     epochs = 60
     starts = 3
     bs = 500
+    lr *= np.sqrt(bs)
-    learner_class = None #SGDR
+    learner_class = SGDR
     restart_decay = 0.5
-    n_dense = 1
-    n_denses = 0
+    n_dense = 2
+    n_denses = 1
     new_dims = (4, 12)
-    activation = Relu # GeluApprox
+    activation = GeluApprox
+    output_activation = Softmax
+    normalize = True
+    optim = MomentumClip(0.8, 0.8)
     reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
     final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
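
The MNIST branch now also uses SGDR with warm restarts, and `restart_decay = 0.5` presumably halves the peak rate at each restart. A rough illustration of the schedule shape under that assumption; the actual decay logic lives inside the library's SGDR class:

    import numpy as np

    epochs, starts, restart_decay = 60, 3, 0.5
    lr = 0.01 * np.sqrt(500)      # hypothetical base rate of 0.01, scaled by sqrt(bs) as above

    peak = lr
    for restart in range(starts):
        print("restart %d: %d epochs, peak rate %.4f" % (restart, epochs // starts, peak))
        peak *= restart_decay     # halve the peak rate after each warm restart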
@@ -58,8 +68,8 @@ else:
     actreg_lamb = None #1e-4
     load_fn = None
-    save_fn = None # 'mnist.h5'
-    log_fn = 'mnist_losses6.npz'
+    save_fn = 'mnist.h5'
+    log_fn = 'mnist_losses.npz'
     fn = 'mnist.npz'
     mnist_dim = 28
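
Saving is switched on by default: the model goes to 'mnist.h5' and the per-epoch logs to 'mnist_losses.npz'. Assuming the log file is an ordinary NumPy .npz archive, as the name suggests, a hypothetical round-trip looks like this (key names taken from the logs DotMap at the end of the script):

    import numpy as np

    # hypothetical round-trip; the actual save happens later in the script
    np.savez_compressed('mnist_losses.npz',
                        valid_losses=[0.40, 0.31],
                        valid_mlosses=[0.88, 0.91])
    logs = np.load('mnist_losses.npz')
    print(sorted(logs.files), logs['valid_mlosses'])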
@@ -91,6 +101,9 @@ def get_mnist(fn='mnist.npz'):
 inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)
+outputs = target_boost(outputs)
+valid_outputs = target_boost(valid_outputs)
 def regulate(y):
     if actreg_lamb:
         assert activation == Relu, activation
@@ -99,7 +112,8 @@ def regulate(y):
         act = ActivityRegularizer(reg)
         reg.lamb_orig = reg.lamb # HACK
         y = y.feed(act)
-    y = y.feed(LayerNorm())
+    if normalize:
+        y = y.feed(LayerNorm())
     if dropout:
         y = y.feed(Dropout(dropout))
     return y
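
The LayerNorm feed in regulate() is now gated on the new `normalize` flag. For reference, a small stand-alone sketch of what layer normalization does numerically, in its standard form; the library's LayerNorm may also carry learned gain and bias parameters:

    import numpy as np

    # normalize each sample's features to zero mean and unit variance
    x = np.array([[1.0, 2.0, 3.0],
                  [2.0, 4.0, 6.0]])
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    print((x - mean) / (std + 1e-5))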
@@ -128,13 +142,10 @@ y = y.feed(activation())
 y = y.feed(Dense(mnist_classes, init=init_glorot_uniform,
                  reg_w=final_reg, reg_b=final_reg))
-y = y.feed(Softmax())
+y = y.feed(output_activation())
 model = Model(x, y, unsafe=True)
-lr *= np.sqrt(bs)
-optim = MomentumClip(0.8, 0.8)
 if learner_class == SGDR:
     learner = learner_class(optim, epochs=epochs//starts, rate=lr,
                             restarts=starts-1, restart_decay=restart_decay,
@@ -142,11 +153,21 @@ if learner_class == SGDR:
 elif learner_class in (TriangularCLR, SineCLR, WaveCLR):
     learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
                             frequency=epochs//starts)
+elif learner_class is AnnealingLearner:
+    learner = learner_class(optim, epochs=epochs, rate=lr,
+                            halve_every=epochs//starts)
+elif learner_class is DumbLearner:
+    learner = learner_class(self, optim, epochs=epochs//starts, rate=lr,
+                            halve_every=epochs//(2*starts),
+                            restarts=starts-1, restart_advance=epochs//starts)
+elif learner_class is Learner:
+    learner = Learner(optim, epochs=epochs, rate=lr)
 else:
-    lament('NOTE: no learning rate schedule selected.')
+    if not isinstance(optim, YellowFin):
+        lament('WARNING: no learning rate schedule selected.')
     learner = Learner(optim, epochs=epochs)
-loss = CategoricalCrossentropy()
+loss = CategoricalCrossentropy() if output_activation == Softmax else SquaredHalved()
 mloss = Accuracy()
 ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
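
The loss is now picked to match the output activation: cross-entropy for a Softmax head, half squared error otherwise. A toy 3-class example of the two formulas in their usual textbook form; the library's CategoricalCrossentropy and SquaredHalved classes may differ in reduction details:

    import numpy as np

    p = np.array([0.7, 0.2, 0.1])   # predicted distribution
    t = np.array([1.0, 0.0, 0.0])   # one-hot target

    cross_entropy = -np.sum(t * np.log(p))        # ~= 0.357
    squared_halved = 0.5 * np.sum((p - t) ** 2)   # ~= 0.070
    print(cross_entropy, squared_halved)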
@@ -165,8 +186,8 @@ logs = DotMap(
     train_mlosses = [],
     valid_losses = [],
     valid_mlosses = [],
-    train_confid = [],
-    valid_confid = [],
+    #train_confid = [],
+    #valid_confid = [],
     learning_rate = [],
     momentum = [],
 )