From e2530c17e5b68c00c3b4d9ea66661cef54642a90 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 18 Feb 2019 06:25:10 +0100 Subject: [PATCH] clean up Ritual a little --- onn/ritual_base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/onn/ritual_base.py b/onn/ritual_base.py index d5c51fa..ab61ed5 100644 --- a/onn/ritual_base.py +++ b/onn/ritual_base.py @@ -26,16 +26,22 @@ class Ritual: # i'm just making up names at this point. self.model = model def _learn(self, inputs, outputs): + """deprecated""" error, predicted = self.model.forward(inputs, outputs) self.model.backward(predicted, outputs) self.model.regulate() return error, predicted def _update(self): + """deprecated""" optim = self.learner.optim optim.model = self.model optim.update(self.model.dW, self.model.W) + def _clear_measurements(self): + self.cumsum_loss, self.cumsum_mloss = _0, _0 + self.losses, self.mlosses = [], [] + def _measure(self, predicted, outputs): loss = self.model.loss.forward(predicted, outputs) if np.isnan(loss): @@ -66,6 +72,7 @@ class Ritual: # i'm just making up names at this point. def _train_batch(self, batch_inputs, batch_outputs, b, batch_count, test_only=False, loss_logging=False, mloss_logging=True): + """deprecated""" if not test_only and self.learner.per_batch: self.learner.batch(b / batch_count) @@ -95,8 +102,7 @@ class Ritual: # i'm just making up names at this point. assert self.model is not None, "call prepare(model) before training" self.en += 1 - self.cumsum_loss, self.cumsum_mloss = _0, _0 - self.losses, self.mlosses = [], [] + self._clear_measurements() for b, (inputs, outputs) in enumerate(batch_gen): self.bn += 1 @@ -113,6 +119,7 @@ class Ritual: # i'm just making up names at this point. batch_size=None, return_losses=False, test_only=False, shuffle=True, clear_grad=True): + """deprecated""" assert isinstance(return_losses, bool) or return_losses == 'both' assert self.model is not None @@ -137,8 +144,7 @@ class Ritual: # i'm just making up names at this point. inputs = inputs[indices] outputs = outputs[indices] - self.cumsum_loss, self.cumsum_mloss = _0, _0 - self.losses, self.mlosses = [], [] + self._clear_measurements() if not gen: batch_count = inputs.shape[0] // batch_size