easier logging of losses etc.
parent 1b1184480a
commit d8bf6d1c5b
1 changed file with 30 additions and 22 deletions

onn_mnist.py (52 changed lines: +30, -22)
--- a/onn_mnist.py
+++ b/onn_mnist.py
@@ -2,6 +2,7 @@
 
 from onn import *
 from onn_core import _f
+from dotmap import DotMap
 
 #np.random.seed(42069)
 
@@ -155,11 +156,18 @@ log('parameters', model.param_count)
 
 ritual.prepare(model)
 
-batch_losses, batch_mlosses = [], []
-train_losses, train_mlosses = [], []
-valid_losses, valid_mlosses = [], []
-
-train_confid, valid_confid = [], []
+logs = DotMap(
+    batch_losses = [],
+    batch_mlosses = [],
+    train_losses = [],
+    train_mlosses = [],
+    valid_losses = [],
+    valid_mlosses = [],
+    train_confid = [],
+    valid_confid = [],
+    learning_rate = [],
+    momentum = [],
+)
 
 def measure_error(quiet=False):
     def print_error(name, inputs, outputs, comparison=None):
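For context, a minimal standalone sketch (assuming only the dotmap package; the keys and values below are illustrative) of how the DotMap-backed logs container behaves: attribute access returns the underlying lists, and the whole bundle still iterates like a dict.

from dotmap import DotMap

# illustrative keys mirroring two of the lists initialised above
logs = DotMap(train_losses=[], valid_losses=[])

# attribute access returns the stored list, so .append works as usual
logs.train_losses.append(0.42)
logs.valid_losses.append(0.37)

# DotMap also behaves like a dict, which the save loop later relies on
for k, v in logs.items():
    print(k, v)  # e.g. "train_losses [0.42]"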
@@ -177,13 +185,13 @@ def measure_error(quiet=False):
         return loss, mloss, confid
 
     loss, mloss, confid = print_error("train", inputs, outputs)
-    train_losses.append(loss)
-    train_mlosses.append(mloss)
-    train_confid.append(confid)
+    logs.train_losses.append(loss)
+    logs.train_mlosses.append(mloss)
+    #logs.train_confid.append(confid)
     loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs)
-    valid_losses.append(loss)
-    valid_mlosses.append(mloss)
-    valid_confid.append(confid)
+    logs.valid_losses.append(loss)
+    logs.valid_mlosses.append(mloss)
+    #logs.valid_confid.append(confid)
 
 measure_error()
 
@@ -202,13 +210,17 @@ while learner.next():
     log("epoch {}".format(learner.epoch),
         fmt.format(learner.rate, avg_loss, avg_mloss * 100))
 
-    batch_losses += losses
-    batch_mlosses += mlosses
+    logs.batch_losses += losses
+    logs.batch_mlosses += mlosses
 
     if measure_every_epoch:
         quiet = learner.epoch != learner.epochs
         measure_error(quiet=quiet)
 
+    logs.learning_rate.append(optim.alpha)
+    if getattr(optim, 'mu', None):
+        logs.momentum.append(optim.mu)
+
 if not measure_every_epoch:
     measure_error()
 
@@ -218,12 +230,8 @@ if save_fn is not None:
 
 if log_fn:
     log('saving losses', log_fn)
-    np.savez_compressed(log_fn,
-                        batch_losses =np.array(batch_losses, dtype=_f),
-                        batch_mlosses=np.array(batch_mlosses, dtype=_f),
-                        train_losses =np.array(train_losses, dtype=_f),
-                        train_mlosses=np.array(train_mlosses, dtype=_f),
-                        valid_losses =np.array(valid_losses, dtype=_f),
-                        valid_mlosses=np.array(valid_mlosses, dtype=_f),
-                        train_confid =np.array(train_confid, dtype=_f),
-                        valid_confid =np.array(valid_confid, dtype=_f))
+    kwargs = dict()
+    for k, v in logs.items():
+        if len(v) > 0:
+            kwargs[k] = np.array(v, dtype=_f)
+    np.savez_compressed(log_fn, **kwargs)
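For later analysis, a small sketch of reading the archive written here (the path is hypothetical; only keys whose lists were non-empty end up in the file, so they are checked before use).

import numpy as np

log_fn = 'mnist_log.npz'  # hypothetical path; use whatever log_fn the run was given

with np.load(log_fn) as data:
    if 'train_losses' in data.files:  # only non-empty lists were saved
        train_losses = data['train_losses']
        print('final train loss:', train_losses[-1])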