diff --git a/onn_mnist.py b/onn_mnist.py index e462c9e..24eae15 100755 --- a/onn_mnist.py +++ b/onn_mnist.py @@ -2,6 +2,7 @@ from onn import * from onn_core import _f +from dotmap import DotMap #np.random.seed(42069) @@ -155,11 +156,18 @@ log('parameters', model.param_count) ritual.prepare(model) -batch_losses, batch_mlosses = [], [] -train_losses, train_mlosses = [], [] -valid_losses, valid_mlosses = [], [] - -train_confid, valid_confid = [], [] +logs = DotMap( + batch_losses = [], + batch_mlosses = [], + train_losses = [], + train_mlosses = [], + valid_losses = [], + valid_mlosses = [], + train_confid = [], + valid_confid = [], + learning_rate = [], + momentum = [], +) def measure_error(quiet=False): def print_error(name, inputs, outputs, comparison=None): @@ -177,13 +185,13 @@ def measure_error(quiet=False): return loss, mloss, confid loss, mloss, confid = print_error("train", inputs, outputs) - train_losses.append(loss) - train_mlosses.append(mloss) - train_confid.append(confid) + logs.train_losses.append(loss) + logs.train_mlosses.append(mloss) + #logs.train_confid.append(confid) loss, mloss, confid = print_error("valid", valid_inputs, valid_outputs) - valid_losses.append(loss) - valid_mlosses.append(mloss) - valid_confid.append(confid) + logs.valid_losses.append(loss) + logs.valid_mlosses.append(mloss) + #logs.valid_confid.append(confid) measure_error() @@ -202,13 +210,17 @@ while learner.next(): log("epoch {}".format(learner.epoch), fmt.format(learner.rate, avg_loss, avg_mloss * 100)) - batch_losses += losses - batch_mlosses += mlosses + logs.batch_losses += losses + logs.batch_mlosses += mlosses if measure_every_epoch: quiet = learner.epoch != learner.epochs measure_error(quiet=quiet) + logs.learning_rate.append(optim.alpha) + if getattr(optim, 'mu', None): + logs.momentum.append(optim.mu) + if not measure_every_epoch: measure_error() @@ -218,12 +230,8 @@ if save_fn is not None: if log_fn: log('saving losses', log_fn) - np.savez_compressed(log_fn, - batch_losses =np.array(batch_losses, dtype=_f), - batch_mlosses=np.array(batch_mlosses, dtype=_f), - train_losses =np.array(train_losses, dtype=_f), - train_mlosses=np.array(train_mlosses, dtype=_f), - valid_losses =np.array(valid_losses, dtype=_f), - valid_mlosses=np.array(valid_mlosses, dtype=_f), - train_confid =np.array(train_confid, dtype=_f), - valid_confid =np.array(valid_confid, dtype=_f)) + kwargs = dict() + for k, v in logs.items(): + if len(v) > 0: + kwargs[k] = np.array(v, dtype=_f) + np.savez_compressed(log_fn, **kwargs)