update mnist training
crank up the learning rate on emnist and use momentum with gradient clipping. add a simple restart callback. remove the batch size adaptation crap. remove confidence measures.
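For context, the swap from Adam to MomentumClip amounts to classical momentum with the gradient clipped before the update. The real MomentumClip class lives in the onn library and is not shown in this diff, so the sketch below is only an assumption about its semantics, guessed from its name and arguments (mu, nesterov):

    # hypothetical sketch of a MomentumClip(mu=0.7, nesterov=True) update step;
    # the clip threshold and exact formulation are assumptions, not the library's code
    import numpy as np

    def momentum_clip_step(w, grad, vel, lr, mu=0.7, clip=1.0, nesterov=True):
        # clip the gradient by its global L2 norm before applying momentum
        norm = np.linalg.norm(grad)
        if norm > clip:
            grad = grad * (clip / norm)
        vel = mu * vel - lr * grad
        if nesterov:
            # simplified Nesterov momentum: step with the "looked-ahead" velocity
            w = w + mu * vel - lr * grad
        else:
            w = w + vel
        return w, vel

    # usage: a gradient of norm 5 gets scaled down to norm 1 before the step
    w, vel = np.zeros(4), np.zeros(4)
    w, vel = momentum_clip_step(w, np.array([3.0, 4.0, 0.0, 0.0]), vel, lr=0.01)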
This commit is contained in:
parent 7ac67fba8f
commit 9138f73141

1 changed file with 22 additions and 22 deletions
onn_mnist.py  (+22 −22)
@@ -7,16 +7,16 @@ from dotmap import DotMap
 lower_priority()
 np.random.seed(42069)
 
-use_emnist = False
-
 measure_every_epoch = True
 
 target_boost = lambda y: y
 
+use_emnist = False
 if use_emnist:
-    lr = 0.005
+    lr = 1.0
     epochs = 48
     starts = 2
     bs = 400
-    lr *= np.sqrt(bs)
 
     learner_class = SGDR
     restart_decay = 0.5
@@ -28,7 +28,8 @@ if use_emnist:
     output_activation = Softmax
     normalize = True
 
-    optim = Adam()
+    optim = MomentumClip(mu=0.7, nesterov=True)
+    restart_optim = False
 
     reg = None # L1L2(2.0e-5, 1.0e-4)
     final_reg = None # L1L2(2.0e-5, 1.0e-4)
@@ -44,11 +45,10 @@ if use_emnist:
     mnist_classes = 47
 
 else:
-    lr = 0.0005
+    lr = 0.01
     epochs = 60
     starts = 3
     bs = 500
-    lr *= np.sqrt(bs)
 
     learner_class = SGDR
     restart_decay = 0.5
@@ -61,6 +61,7 @@ else:
     normalize = True
 
     optim = MomentumClip(0.8, 0.8)
+    restart_optim = False
 
     reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
     final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
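A note on the removed lr *= np.sqrt(bs) lines above: that scaling folded the batch size into the step size. The new hard-coded rates roughly absorb that factor on the mnist branch while genuinely cranking the emnist branch, as this quick check shows (an illustration, not part of the diff):

    import numpy as np
    print(0.005 * np.sqrt(400))    # old effective emnist lr: 0.1   -> now 1.0 (a real 10x crank)
    print(0.0005 * np.sqrt(500))   # old effective mnist lr: ~0.011 -> now 0.01 (about unchanged)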
@@ -146,20 +147,28 @@ y = y.feed(output_activation())
 
 model = Model(x, y, unsafe=True)
 
+def rscb(restart):
+    log("restarting", restart)
+    if restart_optim:
+        optim.reset()
+
 if learner_class == SGDR:
     learner = learner_class(optim, epochs=epochs//starts, rate=lr,
                             restarts=starts-1, restart_decay=restart_decay,
-                            expando=lambda i:0)
+                            expando=lambda i:0,
+                            callback=rscb)
 elif learner_class in (TriangularCLR, SineCLR, WaveCLR):
     learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
-                            frequency=epochs//starts)
+                            frequency=epochs//starts,
+                            callback=rscb)
 elif learner_class is AnnealingLearner:
     learner = learner_class(optim, epochs=epochs, rate=lr,
                             halve_every=epochs//starts)
 elif learner_class is DumbLearner:
     learner = learner_class(self, optim, epochs=epochs//starts, rate=lr,
                             halve_every=epochs//(2*starts),
-                            restarts=starts-1, restart_advance=epochs//starts)
+                            restarts=starts-1, restart_advance=epochs//starts,
+                            callback=rscb)
 elif learner_class is Learner:
     learner = Learner(optim, epochs=epochs, rate=lr)
 else:
@@ -186,8 +195,6 @@ logs = DotMap(
     train_mlosses = [],
     valid_losses = [],
     valid_mlosses = [],
-    #train_confid = [],
-    #valid_confid = [],
     learning_rate = [],
     momentum = [],
 )
@@ -196,25 +203,18 @@ def measure_error(quiet=False):
     def print_error(name, inputs, outputs, comparison=None):
         loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
 
-        c = Confidence()
-        predicted = ritual.model.forward(inputs, deterministic=True)
-        confid = c.forward(predicted)
-
         if not quiet:
             log(name + " loss", "{:12.6e}".format(loss))
             log(name + " accuracy", "{:6.2f}%".format(mloss * 100))
-            log(name + " confidence", "{:6.2f}%".format(confid * 100))
 
-        return loss, mloss, confid
+        return loss, mloss
 
-    loss, mloss, confid = print_error("train", inputs, outputs)
+    loss, mloss = print_error("train", inputs, outputs)
     logs.train_losses.append(loss)
     logs.train_mlosses.append(mloss)
-    #logs.train_confid.append(confid)
-    loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs)
+    loss, mloss = print_error("valid", valid_inputs, valid_outputs)
     logs.valid_losses.append(loss)
     logs.valid_mlosses.append(mloss)
-    #logs.valid_confid.append(confid)
 
 measure_error()
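The new rscb callback only logs the restart and optionally resets optimizer state; when it fires is up to the learner. The SGDR class itself is part of the onn library and not shown here, so the loop below is only a sketch of the usual warm-restart schedule it presumably implements, with the peak rate decayed by restart_decay at each restart:

    # hypothetical sketch of an SGDR-style schedule driving a restart callback;
    # the cosine shape and callback timing are assumptions, not the library's code
    import numpy as np

    def sgdr_schedule(rate, epochs, restarts, restart_decay, callback=None):
        for restart in range(restarts + 1):
            peak = rate * restart_decay**restart  # with restart_decay=0.5, peak lr halves each restart
            for epoch in range(epochs):
                # cosine annealing from peak down toward 0 within each restart period
                yield 0.5 * peak * (1 + np.cos(np.pi * epoch / epochs))
            if callback is not None and restart < restarts:
                callback(restart + 1)  # e.g. rscb: log, and maybe reset the optimizer

    # usage mirroring the emnist branch: 48 epochs, 2 starts -> 24 epochs per period
    for lr in sgdr_schedule(1.0, epochs=24, restarts=1, restart_decay=0.5,
                            callback=lambda r: print("restarting", r)):
        pass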