diff --git a/onn/ritual_base.py b/onn/ritual_base.py index ab61ed5..f60dfd1 100644 --- a/onn/ritual_base.py +++ b/onn/ritual_base.py @@ -59,13 +59,15 @@ class Ritual: # i'm just making up names at this point. if self.learner.per_batch: self.learner.batch(b / batch_count) - error, predicted = self.model.forward(inputs, outputs) - error += self.model.regulate_forward() + loss, predicted = self.model.forward(inputs, outputs) + reg_loss = self.model.regulate_forward() self.model.backward(predicted, outputs) self.model.regulate() optim = self.learner.optim - optim.model = self.model + optim.model = self.model # TODO: optim.inform(model=model) or something + optim.error = predicted - outputs # FIXME: temp + optim.loss = loss # FIXME: temp optim.update(self.model.dW, self.model.W) return predicted