clean up Ritual a little

This commit is contained in:
Connor Olding 2019-02-18 06:25:10 +01:00
parent 69e6ec3fad
commit e2530c17e5

View File

@ -26,16 +26,22 @@ class Ritual: # i'm just making up names at this point.
self.model = model
def _learn(self, inputs, outputs):
"""deprecated"""
error, predicted = self.model.forward(inputs, outputs)
self.model.backward(predicted, outputs)
self.model.regulate()
return error, predicted
def _update(self):
"""deprecated"""
optim = self.learner.optim
optim.model = self.model
optim.update(self.model.dW, self.model.W)
def _clear_measurements(self):
self.cumsum_loss, self.cumsum_mloss = _0, _0
self.losses, self.mlosses = [], []
def _measure(self, predicted, outputs):
loss = self.model.loss.forward(predicted, outputs)
if np.isnan(loss):
@ -66,6 +72,7 @@ class Ritual: # i'm just making up names at this point.
def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
test_only=False, loss_logging=False, mloss_logging=True):
"""deprecated"""
if not test_only and self.learner.per_batch:
self.learner.batch(b / batch_count)
@ -95,8 +102,7 @@ class Ritual: # i'm just making up names at this point.
assert self.model is not None, "call prepare(model) before training"
self.en += 1
self.cumsum_loss, self.cumsum_mloss = _0, _0
self.losses, self.mlosses = [], []
self._clear_measurements()
for b, (inputs, outputs) in enumerate(batch_gen):
self.bn += 1
@ -113,6 +119,7 @@ class Ritual: # i'm just making up names at this point.
batch_size=None,
return_losses=False, test_only=False, shuffle=True,
clear_grad=True):
"""deprecated"""
assert isinstance(return_losses, bool) or return_losses == 'both'
assert self.model is not None
@ -137,8 +144,7 @@ class Ritual: # i'm just making up names at this point.
inputs = inputs[indices]
outputs = outputs[indices]
self.cumsum_loss, self.cumsum_mloss = _0, _0
self.losses, self.mlosses = [], []
self._clear_measurements()
if not gen:
batch_count = inputs.shape[0] // batch_size