easier logging of losses etc.

This commit is contained in:
Connor Olding 2017-07-02 02:53:31 +00:00
parent 1b1184480a
commit d8bf6d1c5b

View file

@ -2,6 +2,7 @@
from onn import * from onn import *
from onn_core import _f from onn_core import _f
from dotmap import DotMap
#np.random.seed(42069) #np.random.seed(42069)
@ -155,11 +156,18 @@ log('parameters', model.param_count)
ritual.prepare(model) ritual.prepare(model)
batch_losses, batch_mlosses = [], [] logs = DotMap(
train_losses, train_mlosses = [], [] batch_losses = [],
valid_losses, valid_mlosses = [], [] batch_mlosses = [],
train_losses = [],
train_confid, valid_confid = [], [] train_mlosses = [],
valid_losses = [],
valid_mlosses = [],
train_confid = [],
valid_confid = [],
learning_rate = [],
momentum = [],
)
def measure_error(quiet=False): def measure_error(quiet=False):
def print_error(name, inputs, outputs, comparison=None): def print_error(name, inputs, outputs, comparison=None):
@ -177,13 +185,13 @@ def measure_error(quiet=False):
return loss, mloss, confid return loss, mloss, confid
loss, mloss, confid = print_error("train", inputs, outputs) loss, mloss, confid = print_error("train", inputs, outputs)
train_losses.append(loss) logs.train_losses.append(loss)
train_mlosses.append(mloss) logs.train_mlosses.append(mloss)
train_confid.append(confid) #logs.train_confid.append(confid)
loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs) loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs)
valid_losses.append(loss) logs.valid_losses.append(loss)
valid_mlosses.append(mloss) logs.valid_mlosses.append(mloss)
valid_confid.append(confid) #logs.valid_confid.append(confid)
measure_error() measure_error()
@ -202,13 +210,17 @@ while learner.next():
log("epoch {}".format(learner.epoch), log("epoch {}".format(learner.epoch),
fmt.format(learner.rate, avg_loss, avg_mloss * 100)) fmt.format(learner.rate, avg_loss, avg_mloss * 100))
batch_losses += losses logs.batch_losses += losses
batch_mlosses += mlosses logs.batch_mlosses += mlosses
if measure_every_epoch: if measure_every_epoch:
quiet = learner.epoch != learner.epochs quiet = learner.epoch != learner.epochs
measure_error(quiet=quiet) measure_error(quiet=quiet)
logs.learning_rate.append(optim.alpha)
if getattr(optim, 'mu', None):
logs.momentum.append(optim.mu)
if not measure_every_epoch: if not measure_every_epoch:
measure_error() measure_error()
@ -218,12 +230,8 @@ if save_fn is not None:
if log_fn: if log_fn:
log('saving losses', log_fn) log('saving losses', log_fn)
np.savez_compressed(log_fn, kwargs = dict()
batch_losses =np.array(batch_losses, dtype=_f), for k, v in logs.items():
batch_mlosses=np.array(batch_mlosses, dtype=_f), if len(v) > 0:
train_losses =np.array(train_losses, dtype=_f), kwargs[k] = np.array(v, dtype=_f)
train_mlosses=np.array(train_mlosses, dtype=_f), np.savez_compressed(log_fn, **kwargs)
valid_losses =np.array(valid_losses, dtype=_f),
valid_mlosses=np.array(valid_mlosses, dtype=_f),
train_confid =np.array(train_confid, dtype=_f),
valid_confid =np.array(valid_confid, dtype=_f))