update mnist example
This commit is contained in:
parent
5183cd38f8
commit
6933e21e0e
1 changed files with 38 additions and 17 deletions
55
onn_mnist.py
55
onn_mnist.py
|
@ -5,9 +5,9 @@ from onn_core import _f
|
|||
from dotmap import DotMap
|
||||
|
||||
lower_priority()
|
||||
#np.random.seed(42069)
|
||||
np.random.seed(42069)
|
||||
|
||||
use_emnist = True
|
||||
use_emnist = False
|
||||
|
||||
measure_every_epoch = True
|
||||
|
||||
|
@ -16,6 +16,7 @@ if use_emnist:
|
|||
epochs = 48
|
||||
starts = 2
|
||||
bs = 400
|
||||
lr *= np.sqrt(bs)
|
||||
|
||||
learner_class = SGDR
|
||||
restart_decay = 0.5
|
||||
|
@ -24,6 +25,10 @@ if use_emnist:
|
|||
n_denses = 0
|
||||
new_dims = (28, 28)
|
||||
activation = GeluApprox
|
||||
output_activation = Softmax
|
||||
normalize = True
|
||||
|
||||
optim = Adam()
|
||||
|
||||
reg = None # L1L2(2.0e-5, 1.0e-4)
|
||||
final_reg = None # L1L2(2.0e-5, 1.0e-4)
|
||||
|
@ -43,14 +48,19 @@ else:
|
|||
epochs = 60
|
||||
starts = 3
|
||||
bs = 500
|
||||
lr *= np.sqrt(bs)
|
||||
|
||||
learner_class = None #SGDR
|
||||
learner_class = SGDR
|
||||
restart_decay = 0.5
|
||||
|
||||
n_dense = 1
|
||||
n_denses = 0
|
||||
n_dense = 2
|
||||
n_denses = 1
|
||||
new_dims = (4, 12)
|
||||
activation = Relu # GeluApprox
|
||||
activation = GeluApprox
|
||||
output_activation = Softmax
|
||||
normalize = True
|
||||
|
||||
optim = MomentumClip(0.8, 0.8)
|
||||
|
||||
reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
|
||||
final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
|
||||
|
@ -58,8 +68,8 @@ else:
|
|||
actreg_lamb = None #1e-4
|
||||
|
||||
load_fn = None
|
||||
save_fn = None # 'mnist.h5'
|
||||
log_fn = 'mnist_losses6.npz'
|
||||
save_fn = 'mnist.h5'
|
||||
log_fn = 'mnist_losses.npz'
|
||||
|
||||
fn = 'mnist.npz'
|
||||
mnist_dim = 28
|
||||
|
@ -91,6 +101,9 @@ def get_mnist(fn='mnist.npz'):
|
|||
|
||||
inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)
|
||||
|
||||
outputs = target_boost(outputs)
|
||||
valid_outputs = target_boost(valid_outputs)
|
||||
|
||||
def regulate(y):
|
||||
if actreg_lamb:
|
||||
assert activation == Relu, activation
|
||||
|
@ -99,7 +112,8 @@ def regulate(y):
|
|||
act = ActivityRegularizer(reg)
|
||||
reg.lamb_orig = reg.lamb # HACK
|
||||
y = y.feed(act)
|
||||
y = y.feed(LayerNorm())
|
||||
if normalize:
|
||||
y = y.feed(LayerNorm())
|
||||
if dropout:
|
||||
y = y.feed(Dropout(dropout))
|
||||
return y
|
||||
|
@ -128,13 +142,10 @@ y = y.feed(activation())
|
|||
|
||||
y = y.feed(Dense(mnist_classes, init=init_glorot_uniform,
|
||||
reg_w=final_reg, reg_b=final_reg))
|
||||
y = y.feed(Softmax())
|
||||
y = y.feed(output_activation())
|
||||
|
||||
model = Model(x, y, unsafe=True)
|
||||
|
||||
lr *= np.sqrt(bs)
|
||||
|
||||
optim = MomentumClip(0.8, 0.8)
|
||||
if learner_class == SGDR:
|
||||
learner = learner_class(optim, epochs=epochs//starts, rate=lr,
|
||||
restarts=starts-1, restart_decay=restart_decay,
|
||||
|
@ -142,11 +153,21 @@ if learner_class == SGDR:
|
|||
elif learner_class in (TriangularCLR, SineCLR, WaveCLR):
|
||||
learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||
frequency=epochs//starts)
|
||||
elif learner_class is AnnealingLearner:
|
||||
learner = learner_class(optim, epochs=epochs, rate=lr,
|
||||
halve_every=epochs//starts)
|
||||
elif learner_class is DumbLearner:
|
||||
learner = learner_class(self, optim, epochs=epochs//starts, rate=lr,
|
||||
halve_every=epochs//(2*starts),
|
||||
restarts=starts-1, restart_advance=epochs//starts)
|
||||
elif learner_class is Learner:
|
||||
learner = Learner(optim, epochs=epochs, rate=lr)
|
||||
else:
|
||||
lament('NOTE: no learning rate schedule selected.')
|
||||
if not isinstance(optim, YellowFin):
|
||||
lament('WARNING: no learning rate schedule selected.')
|
||||
learner = Learner(optim, epochs=epochs)
|
||||
|
||||
loss = CategoricalCrossentropy()
|
||||
loss = CategoricalCrossentropy() if output_activation == Softmax else SquaredHalved()
|
||||
mloss = Accuracy()
|
||||
|
||||
ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
|
||||
|
@ -165,8 +186,8 @@ logs = DotMap(
|
|||
train_mlosses = [],
|
||||
valid_losses = [],
|
||||
valid_mlosses = [],
|
||||
train_confid = [],
|
||||
valid_confid = [],
|
||||
#train_confid = [],
|
||||
#valid_confid = [],
|
||||
learning_rate = [],
|
||||
momentum = [],
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue