This commit is contained in:
parent fcdb7e1918
commit 65fe5cad85

2 changed files with 41 additions and 9 deletions
@@ -723,7 +723,7 @@ class Ritual: # i'm just making up names at this point
             batch_outputs = outputs[bi:bi+batch_size]

             if not test_only and self.learner.per_batch:
-                self.learner.batch(b / batch_count)
+                self.learner.batch(b / batch_count)

             predicted = self.learn(batch_inputs, batch_outputs)
             if not test_only:
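Context for the hunk above: when a learner sets per_batch, the training loop hands it the fraction of the epoch completed (b / batch_count), so the learning rate can be adjusted between batches rather than only between epochs. A minimal sketch of that contract, assuming only the per_batch/batch() interface visible here (the cosine decay body is illustrative, not this repo's code):

    import math

    class ToyPerBatchLearner:
        # Mirrors the interface used above: per_batch opts in, and
        # batch(progress) receives b / batch_count in [0, 1).
        per_batch = True

        def __init__(self, rate):
            self.start_rate = rate
            self.rate = rate

        def batch(self, progress):
            # Illustrative within-epoch cosine decay of the learning rate.
            self.rate = self.start_rate * 0.5 * (1.0 + math.cos(math.pi * progress))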
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3

 from optim_nn import *
+from optim_nn_core import _f

 #np.random.seed(42069)
@@ -13,8 +14,12 @@ from optim_nn import *
 lr = 0.01
 epochs = 24
 starts = 2
+restart_decay = 0.5
 bs = 100

+log_fn = 'mnist_losses.npz'
+measure_every_epoch = True
+
 mnist_dim = 28
 mnist_classes = 10
 def get_mnist(fn='mnist.npz'):
@@ -60,7 +65,7 @@ model = Model(x, y, unsafe=True)

 optim = Adam()
 learner = SGDR(optim, epochs=epochs//starts, rate=lr,
-               restarts=starts - 1, restart_decay=0.5,
+               restarts=starts - 1, restart_decay=restart_decay,
                expando=lambda i:0)

 loss = CategoricalCrossentropy()
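The net effect of the hunk above is that the restart decay is no longer hard-coded. Assuming SGDR here follows the warm-restart schedule of the same name (train in segments, multiplying the peak rate by restart_decay at each restart), the settings from the top of the script work out as below; this arithmetic is a sketch of the assumed semantics, not code from this repo:

    lr, epochs, starts, restart_decay = 0.01, 24, 2, 0.5
    for i in range(starts):  # restarts = starts - 1 = 1
        print("segment", i, "epochs:", epochs // starts,
              "peak rate:", lr * restart_decay**i)
    # segment 0 epochs: 12 peak rate: 0.01
    # segment 1 epochs: 12 peak rate: 0.005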
@@ -72,15 +77,24 @@ log('parameters', model.param_count)

 ritual.prepare(model)

-def measure_error():
+batch_losses, batch_mlosses = [], []
+train_losses, train_mlosses = [], []
+valid_losses, valid_mlosses = [], []
+
+def measure_error(quiet=False):
     def print_error(name, inputs, outputs, comparison=None):
         loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
-        log(name + " loss", "{:12.6e}".format(loss))
-        log(name + " accuracy", "{:6.2f}%".format(mloss * 100))
+        if not quiet:
+            log(name + " loss", "{:12.6e}".format(loss))
+            log(name + " accuracy", "{:6.2f}%".format(mloss * 100))
+        return loss, mloss

-    print_error("train", inputs, outputs)
-    print_error("valid", valid_inputs, valid_outputs)
+    loss, mloss = print_error("train", inputs, outputs)
+    train_losses.append(loss)
+    train_mlosses.append(mloss)
+    loss, mloss = print_error("valid", valid_inputs, valid_outputs)
+    valid_losses.append(loss)
+    valid_mlosses.append(mloss)

 measure_error()
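After this change, measure_error both reports and records: every call appends exactly one entry to each of train_losses/train_mlosses and valid_losses/valid_mlosses, so the four per-measurement curves stay index-aligned, with entry 0 being the pre-training baseline from the measure_error() call at the bottom of the hunk. quiet=True suppresses only the printing, never the recording. A self-contained sketch of the same accumulate-and-optionally-report pattern (evaluate is a stand-in for ritual.test_batched):

    train_hist, valid_hist = [], []

    def measure(evaluate, quiet=False):
        for name, hist in (("train", train_hist), ("valid", valid_hist)):
            loss, mloss = evaluate(name)   # stand-in for ritual.test_batched(...)
            hist.append((loss, mloss))     # always record
            if not quiet:                  # report only when asked
                print(name, "loss", "{:12.6e}".format(loss))
                print(name, "accuracy", "{:6.2f}%".format(mloss * 100))

    # e.g.: measure(lambda name: (0.123, 0.97))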
@@ -90,7 +104,7 @@ while learner.next():
     shuffled_inputs = inputs[indices]
     shuffled_outputs = outputs[indices]

-    avg_loss, avg_mloss, _, _ = ritual.train_batched(
+    avg_loss, avg_mloss, losses, mlosses = ritual.train_batched(
         shuffled_inputs, shuffled_outputs,
         batch_size=bs,
         return_losses='both')
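The only change here is unpacking the per-batch lists that return_losses='both' evidently already provided: losses and mlosses hold one value per batch, while avg_loss/avg_mloss summarize the epoch. A rough sense of the resulting sizes, assuming the full MNIST training set is used:

    n_train, bs = 60000, 100        # assumed training-set size; bs is from the script
    batch_count = n_train // bs     # 600 batches per epoch
    # after E epochs: len(batch_losses) == E * batch_count,
    # while train_losses/valid_losses grow by one per measure_error() call.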
@@ -98,4 +112,22 @@ while learner.next():
     log("epoch {}".format(learner.epoch + 1),
         fmt.format(learner.rate, avg_loss, avg_mloss * 100))

-    measure_error()
+    batch_losses += losses
+    batch_mlosses += mlosses
+
+    if measure_every_epoch:
+        quiet = learner.epoch + 1 != learner.epochs
+        measure_error(quiet=quiet)
+
+if not measure_every_epoch:
+    measure_error()
+
+if log_fn:
+    log('saving losses', log_fn)
+    np.savez_compressed(log_fn,
+        batch_losses =np.array(batch_losses, dtype=_f),
+        batch_mlosses=np.array(batch_mlosses, dtype=_f),
+        train_losses =np.array(train_losses, dtype=_f),
+        train_mlosses=np.array(train_mlosses, dtype=_f),
+        valid_losses =np.array(valid_losses, dtype=_f),
+        valid_mlosses=np.array(valid_mlosses, dtype=_f))
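_f, imported at the top of the diff, is presumably the library's default float dtype; casting through it keeps the saved arrays compact and consistent. The resulting .npz can be inspected with numpy alone; a minimal sketch (the matplotlib lines are illustrative, not part of this repo):

    import numpy as np

    data = np.load('mnist_losses.npz')
    for key in data.files:
        print(key, data[key].shape)

    # mloss is logged as accuracy elsewhere, so valid_mlosses * 100 is percent:
    # import matplotlib.pyplot as plt
    # plt.plot(data['valid_mlosses'] * 100)
    # plt.ylabel('valid accuracy (%)')
    # plt.show()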