easier logging of losses etc.
This commit is contained in:
parent
1b1184480a
commit
d8bf6d1c5b
1 changed files with 30 additions and 22 deletions
52
onn_mnist.py
52
onn_mnist.py
|
@ -2,6 +2,7 @@
|
|||
|
||||
from onn import *
|
||||
from onn_core import _f
|
||||
from dotmap import DotMap
|
||||
|
||||
#np.random.seed(42069)
|
||||
|
||||
|
@ -155,11 +156,18 @@ log('parameters', model.param_count)
|
|||
|
||||
ritual.prepare(model)
|
||||
|
||||
batch_losses, batch_mlosses = [], []
|
||||
train_losses, train_mlosses = [], []
|
||||
valid_losses, valid_mlosses = [], []
|
||||
|
||||
train_confid, valid_confid = [], []
|
||||
logs = DotMap(
|
||||
batch_losses = [],
|
||||
batch_mlosses = [],
|
||||
train_losses = [],
|
||||
train_mlosses = [],
|
||||
valid_losses = [],
|
||||
valid_mlosses = [],
|
||||
train_confid = [],
|
||||
valid_confid = [],
|
||||
learning_rate = [],
|
||||
momentum = [],
|
||||
)
|
||||
|
||||
def measure_error(quiet=False):
|
||||
def print_error(name, inputs, outputs, comparison=None):
|
||||
|
@ -177,13 +185,13 @@ def measure_error(quiet=False):
|
|||
return loss, mloss, confid
|
||||
|
||||
loss, mloss, confid = print_error("train", inputs, outputs)
|
||||
train_losses.append(loss)
|
||||
train_mlosses.append(mloss)
|
||||
train_confid.append(confid)
|
||||
logs.train_losses.append(loss)
|
||||
logs.train_mlosses.append(mloss)
|
||||
#logs.train_confid.append(confid)
|
||||
loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs)
|
||||
valid_losses.append(loss)
|
||||
valid_mlosses.append(mloss)
|
||||
valid_confid.append(confid)
|
||||
logs.valid_losses.append(loss)
|
||||
logs.valid_mlosses.append(mloss)
|
||||
#logs.valid_confid.append(confid)
|
||||
|
||||
measure_error()
|
||||
|
||||
|
@ -202,13 +210,17 @@ while learner.next():
|
|||
log("epoch {}".format(learner.epoch),
|
||||
fmt.format(learner.rate, avg_loss, avg_mloss * 100))
|
||||
|
||||
batch_losses += losses
|
||||
batch_mlosses += mlosses
|
||||
logs.batch_losses += losses
|
||||
logs.batch_mlosses += mlosses
|
||||
|
||||
if measure_every_epoch:
|
||||
quiet = learner.epoch != learner.epochs
|
||||
measure_error(quiet=quiet)
|
||||
|
||||
logs.learning_rate.append(optim.alpha)
|
||||
if getattr(optim, 'mu', None):
|
||||
logs.momentum.append(optim.mu)
|
||||
|
||||
if not measure_every_epoch:
|
||||
measure_error()
|
||||
|
||||
|
@ -218,12 +230,8 @@ if save_fn is not None:
|
|||
|
||||
if log_fn:
|
||||
log('saving losses', log_fn)
|
||||
np.savez_compressed(log_fn,
|
||||
batch_losses =np.array(batch_losses, dtype=_f),
|
||||
batch_mlosses=np.array(batch_mlosses, dtype=_f),
|
||||
train_losses =np.array(train_losses, dtype=_f),
|
||||
train_mlosses=np.array(train_mlosses, dtype=_f),
|
||||
valid_losses =np.array(valid_losses, dtype=_f),
|
||||
valid_mlosses=np.array(valid_mlosses, dtype=_f),
|
||||
train_confid =np.array(train_confid, dtype=_f),
|
||||
valid_confid =np.array(valid_confid, dtype=_f))
|
||||
kwargs = dict()
|
||||
for k, v in logs.items():
|
||||
if len(v) > 0:
|
||||
kwargs[k] = np.array(v, dtype=_f)
|
||||
np.savez_compressed(log_fn, **kwargs)
|
||||
|
|
Loading…
Reference in a new issue