clean up Ritual a little
This commit is contained in:
parent
69e6ec3fad
commit
e2530c17e5
1 changed files with 10 additions and 4 deletions
|
@ -26,16 +26,22 @@ class Ritual: # i'm just making up names at this point.
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def _learn(self, inputs, outputs):
|
def _learn(self, inputs, outputs):
|
||||||
|
"""deprecated"""
|
||||||
error, predicted = self.model.forward(inputs, outputs)
|
error, predicted = self.model.forward(inputs, outputs)
|
||||||
self.model.backward(predicted, outputs)
|
self.model.backward(predicted, outputs)
|
||||||
self.model.regulate()
|
self.model.regulate()
|
||||||
return error, predicted
|
return error, predicted
|
||||||
|
|
||||||
def _update(self):
|
def _update(self):
|
||||||
|
"""deprecated"""
|
||||||
optim = self.learner.optim
|
optim = self.learner.optim
|
||||||
optim.model = self.model
|
optim.model = self.model
|
||||||
optim.update(self.model.dW, self.model.W)
|
optim.update(self.model.dW, self.model.W)
|
||||||
|
|
||||||
|
def _clear_measurements(self):
|
||||||
|
self.cumsum_loss, self.cumsum_mloss = _0, _0
|
||||||
|
self.losses, self.mlosses = [], []
|
||||||
|
|
||||||
def _measure(self, predicted, outputs):
|
def _measure(self, predicted, outputs):
|
||||||
loss = self.model.loss.forward(predicted, outputs)
|
loss = self.model.loss.forward(predicted, outputs)
|
||||||
if np.isnan(loss):
|
if np.isnan(loss):
|
||||||
|
@ -66,6 +72,7 @@ class Ritual: # i'm just making up names at this point.
|
||||||
|
|
||||||
def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
|
def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
|
||||||
test_only=False, loss_logging=False, mloss_logging=True):
|
test_only=False, loss_logging=False, mloss_logging=True):
|
||||||
|
"""deprecated"""
|
||||||
if not test_only and self.learner.per_batch:
|
if not test_only and self.learner.per_batch:
|
||||||
self.learner.batch(b / batch_count)
|
self.learner.batch(b / batch_count)
|
||||||
|
|
||||||
|
@ -95,8 +102,7 @@ class Ritual: # i'm just making up names at this point.
|
||||||
assert self.model is not None, "call prepare(model) before training"
|
assert self.model is not None, "call prepare(model) before training"
|
||||||
self.en += 1
|
self.en += 1
|
||||||
|
|
||||||
self.cumsum_loss, self.cumsum_mloss = _0, _0
|
self._clear_measurements()
|
||||||
self.losses, self.mlosses = [], []
|
|
||||||
|
|
||||||
for b, (inputs, outputs) in enumerate(batch_gen):
|
for b, (inputs, outputs) in enumerate(batch_gen):
|
||||||
self.bn += 1
|
self.bn += 1
|
||||||
|
@ -113,6 +119,7 @@ class Ritual: # i'm just making up names at this point.
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
return_losses=False, test_only=False, shuffle=True,
|
return_losses=False, test_only=False, shuffle=True,
|
||||||
clear_grad=True):
|
clear_grad=True):
|
||||||
|
"""deprecated"""
|
||||||
assert isinstance(return_losses, bool) or return_losses == 'both'
|
assert isinstance(return_losses, bool) or return_losses == 'both'
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
|
|
||||||
|
@ -137,8 +144,7 @@ class Ritual: # i'm just making up names at this point.
|
||||||
inputs = inputs[indices]
|
inputs = inputs[indices]
|
||||||
outputs = outputs[indices]
|
outputs = outputs[indices]
|
||||||
|
|
||||||
self.cumsum_loss, self.cumsum_mloss = _0, _0
|
self._clear_measurements()
|
||||||
self.losses, self.mlosses = [], []
|
|
||||||
|
|
||||||
if not gen:
|
if not gen:
|
||||||
batch_count = inputs.shape[0] // batch_size
|
batch_count = inputs.shape[0] // batch_size
|
||||||
|
|
Loading…
Add table
Reference in a new issue