clean up Ritual a little
This commit is contained in:
parent
69e6ec3fad
commit
e2530c17e5
1 changed files with 10 additions and 4 deletions
|
@ -26,16 +26,22 @@ class Ritual: # i'm just making up names at this point.
|
|||
self.model = model
|
||||
|
||||
def _learn(self, inputs, outputs):
|
||||
"""deprecated"""
|
||||
error, predicted = self.model.forward(inputs, outputs)
|
||||
self.model.backward(predicted, outputs)
|
||||
self.model.regulate()
|
||||
return error, predicted
|
||||
|
||||
def _update(self):
|
||||
"""deprecated"""
|
||||
optim = self.learner.optim
|
||||
optim.model = self.model
|
||||
optim.update(self.model.dW, self.model.W)
|
||||
|
||||
def _clear_measurements(self):
|
||||
self.cumsum_loss, self.cumsum_mloss = _0, _0
|
||||
self.losses, self.mlosses = [], []
|
||||
|
||||
def _measure(self, predicted, outputs):
|
||||
loss = self.model.loss.forward(predicted, outputs)
|
||||
if np.isnan(loss):
|
||||
|
@ -66,6 +72,7 @@ class Ritual: # i'm just making up names at this point.
|
|||
|
||||
def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
|
||||
test_only=False, loss_logging=False, mloss_logging=True):
|
||||
"""deprecated"""
|
||||
if not test_only and self.learner.per_batch:
|
||||
self.learner.batch(b / batch_count)
|
||||
|
||||
|
@ -95,8 +102,7 @@ class Ritual: # i'm just making up names at this point.
|
|||
assert self.model is not None, "call prepare(model) before training"
|
||||
self.en += 1
|
||||
|
||||
self.cumsum_loss, self.cumsum_mloss = _0, _0
|
||||
self.losses, self.mlosses = [], []
|
||||
self._clear_measurements()
|
||||
|
||||
for b, (inputs, outputs) in enumerate(batch_gen):
|
||||
self.bn += 1
|
||||
|
@ -113,6 +119,7 @@ class Ritual: # i'm just making up names at this point.
|
|||
batch_size=None,
|
||||
return_losses=False, test_only=False, shuffle=True,
|
||||
clear_grad=True):
|
||||
"""deprecated"""
|
||||
assert isinstance(return_losses, bool) or return_losses == 'both'
|
||||
assert self.model is not None
|
||||
|
||||
|
@ -137,8 +144,7 @@ class Ritual: # i'm just making up names at this point.
|
|||
inputs = inputs[indices]
|
||||
outputs = outputs[indices]
|
||||
|
||||
self.cumsum_loss, self.cumsum_mloss = _0, _0
|
||||
self.losses, self.mlosses = [], []
|
||||
self._clear_measurements()
|
||||
|
||||
if not gen:
|
||||
batch_count = inputs.shape[0] // batch_size
|
||||
|
|
Loading…
Reference in a new issue