update mnist training

crank up the learning rate on emnist and use momentum with gradient clipping.
add a simple restart callback.
remove batch size adaptation crap.
remove confidence measures.
Connor Olding 2017-08-03 03:35:02 +00:00
parent 7ac67fba8f
commit 9138f73141


@@ -7,16 +7,16 @@ from dotmap import DotMap
 
 lower_priority()
 np.random.seed(42069)
 
-use_emnist = False
 measure_every_epoch = True
+target_boost = lambda y: y
 
+use_emnist = False
 if use_emnist:
-    lr = 0.005
+    lr = 1.0
     epochs = 48
     starts = 2
     bs = 400
-    lr *= np.sqrt(bs)
 
     learner_class = SGDR
     restart_decay = 0.5
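
For scale: the old emnist branch scaled the rate by √bs, i.e. an effective 0.005 · √400 = 0.1, so the new flat 1.0 is roughly a tenfold increase — though it now feeds a plain momentum optimizer rather than Adam (next hunk), so the two numbers are not directly comparable.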
@@ -28,7 +28,8 @@ if use_emnist:
     output_activation = Softmax
     normalize = True
 
-    optim = Adam()
+    optim = MomentumClip(mu=0.7, nesterov=True)
+    restart_optim = False
 
     reg = None # L1L2(2.0e-5, 1.0e-4)
     final_reg = None # L1L2(2.0e-5, 1.0e-4)
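
The emnist branch drops Adam in favor of the repo's MomentumClip optimizer. As a rough sketch of what "momentum with gradient clipping" means here — the class name, constructor arguments, and exact clipping rule below are illustrative assumptions, not the library's actual implementation:

    import numpy as np

    class MomentumClipSketch:
        # illustrative only: clip the gradient's L2 norm, then apply
        # (optionally Nesterov) momentum; lr is set externally by the learner.
        def __init__(self, lr=0.01, mu=0.7, norm=1.0, nesterov=True):
            self.lr = lr
            self.mu = mu              # momentum coefficient
            self.norm = norm          # maximum allowed gradient L2 norm
            self.nesterov = nesterov
            self.accum = None         # velocity buffer

        def reset(self):
            # forget accumulated velocity (what rscb would trigger
            # when restart_optim is enabled)
            self.accum = None

        def compute(self, dW):
            # clip, then accumulate velocity and return the weight update
            n = np.linalg.norm(dW)
            if n > self.norm:
                dW = dW * (self.norm / n)
            if self.accum is None:
                self.accum = np.zeros_like(dW)
            self.accum = self.mu * self.accum - self.lr * dW
            if self.nesterov:
                # "look-ahead" form of Nesterov momentum
                return self.mu * self.accum - self.lr * dW
            return self.accum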
@@ -44,11 +45,10 @@ if use_emnist:
 
     mnist_classes = 47
 else:
-    lr = 0.0005
+    lr = 0.01
     epochs = 60
     starts = 3
     bs = 500
-    lr *= np.sqrt(bs)
 
     learner_class = SGDR
     restart_decay = 0.5
@ -61,6 +61,7 @@ else:
normalize = True normalize = True
optim = MomentumClip(0.8, 0.8) optim = MomentumClip(0.8, 0.8)
restart_optim = False
reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4) reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3) final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
@@ -146,20 +147,28 @@ y = y.feed(output_activation())
 
 model = Model(x, y, unsafe=True)
 
+def rscb(restart):
+    log("restarting", restart)
+    if restart_optim:
+        optim.reset()
+
 if learner_class == SGDR:
     learner = learner_class(optim, epochs=epochs//starts, rate=lr,
                             restarts=starts-1, restart_decay=restart_decay,
-                            expando=lambda i:0)
+                            expando=lambda i:0,
+                            callback=rscb)
 elif learner_class in (TriangularCLR, SineCLR, WaveCLR):
     learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
-                            frequency=epochs//starts)
+                            frequency=epochs//starts,
+                            callback=rscb)
 elif learner_class is AnnealingLearner:
     learner = learner_class(optim, epochs=epochs, rate=lr,
                             halve_every=epochs//starts)
 elif learner_class is DumbLearner:
     learner = learner_class(self, optim, epochs=epochs//starts, rate=lr,
                             halve_every=epochs//(2*starts),
-                            restarts=starts-1, restart_advance=epochs//starts)
+                            restarts=starts-1, restart_advance=epochs//starts,
+                            callback=rscb)
 elif learner_class is Learner:
     learner = Learner(optim, epochs=epochs, rate=lr)
 else:
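
The rscb callback added above is handed to the learner via callback=rscb and fired at each warm restart. A hedged sketch of that interaction (this is not the repo's actual SGDR class; the names and cosine schedule below are assumptions for illustration):

    import math

    class WarmRestartSketch:
        # decay the peak rate each cycle and fire a callback at every restart,
        # which is where something like rscb can log and optionally reset the optimizer
        def __init__(self, epochs, rate, restart_decay=0.5, callback=None):
            self.period = epochs          # epochs per cycle (epochs//starts above)
            self.base_rate = rate
            self.restart_decay = restart_decay
            self.callback = callback

        def rate_for(self, epoch):
            cycle, t = divmod(epoch, self.period)
            if t == 0 and cycle > 0 and self.callback is not None:
                self.callback(cycle)      # e.g. rscb(restart): log, maybe optim.reset()
            peak = self.base_rate * self.restart_decay ** cycle
            # cosine-anneal from the (decayed) peak down toward zero within the cycle
            return 0.5 * peak * (1.0 + math.cos(math.pi * t / self.period))

With restart_optim = False as configured above, the callback only logs each restart; flipping it to True would also clear the optimizer's momentum at every restart.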
@@ -186,8 +195,6 @@ logs = DotMap(
     train_mlosses = [],
     valid_losses = [],
     valid_mlosses = [],
-    #train_confid = [],
-    #valid_confid = [],
     learning_rate = [],
     momentum = [],
 )
@@ -196,25 +203,18 @@ def measure_error(quiet=False):
     def print_error(name, inputs, outputs, comparison=None):
         loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
 
-        c = Confidence()
-        predicted = ritual.model.forward(inputs, deterministic=True)
-        confid = c.forward(predicted)
-
         if not quiet:
             log(name + " loss", "{:12.6e}".format(loss))
             log(name + " accuracy", "{:6.2f}%".format(mloss * 100))
-            log(name + " confidence", "{:6.2f}%".format(confid * 100))
 
-        return loss, mloss, confid
+        return loss, mloss
 
-    loss, mloss, confid = print_error("train", inputs, outputs)
+    loss, mloss = print_error("train", inputs, outputs)
     logs.train_losses.append(loss)
     logs.train_mlosses.append(mloss)
-    #logs.train_confid.append(confid)
 
-    loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs)
+    loss, mloss = print_error("valid", valid_inputs, valid_outputs)
     logs.valid_losses.append(loss)
     logs.valid_mlosses.append(mloss)
-    #logs.valid_confid.append(confid)
 
 measure_error()