parent 547017a6fc
commit a1b1daf3bf
1 changed file with 37 additions and 38 deletions
@@ -701,59 +701,58 @@ class Ritual: # i'm just making up names at this point
         self.bn = 0
         self.model = model
 
-    def train_batched(self, inputs, outputs, batch_size, return_losses=False):
-        self.en += 1
-        cumsum_loss = _0
+    def train_batched(self, inputs, outputs, batch_size,
+                      return_losses=False, test_only=False):
+        assert isinstance(return_losses, bool) or return_losses == 'both'
+
+        if not test_only:
+            self.en += 1
+
+        cumsum_loss, cumsum_mloss = _0, _0
         batch_count = inputs.shape[0] // batch_size
-        losses = []
+        losses, mlosses = [], []
+
         assert inputs.shape[0] % batch_size == 0, \
           "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
         for b in range(batch_count):
-            self.bn += 1
+            if not test_only:
+                self.bn += 1
+
             bi = b * batch_size
             batch_inputs  = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
-            if self.learner.per_batch:
+            if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)
 
             predicted = self.learn(batch_inputs, batch_outputs)
-            self.update()
+            if not test_only:
+                self.update()
 
-            batch_loss = self.measure(predicted, batch_outputs)
-            if np.isnan(batch_loss):
-                raise Exception("nan")
-            cumsum_loss += batch_loss
-            if return_losses:
-                losses.append(batch_loss)
-        avg_loss = cumsum_loss / _f(batch_count)
-        if return_losses:
-            return avg_loss, losses
-        return avg_loss
+            if return_losses == 'both':
+                batch_loss = self.loss.forward(predicted, batch_outputs)
+                if np.isnan(batch_loss):
+                    raise Exception("nan")
+                losses.append(batch_loss)
+                cumsum_loss += batch_loss
 
-    def test_batched(self, inputs, outputs, batch_size, return_losses=False):
-        cumsum_loss = _0
-        batch_count = inputs.shape[0] // batch_size
-        losses = []
-        assert inputs.shape[0] % batch_size == 0, \
-          "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
-        for b in range(batch_count):
-            bi = b * batch_size
-            batch_inputs  = inputs[ bi:bi+batch_size]
-            batch_outputs = outputs[bi:bi+batch_size]
-
-            predicted = self.model.forward(batch_inputs)
-
-            batch_loss = self.measure(predicted, batch_outputs)
-            if np.isnan(batch_loss):
+            batch_mloss = self.measure(predicted, batch_outputs)
+            if np.isnan(batch_mloss):
                 raise Exception("nan")
-            cumsum_loss += batch_loss
             if return_losses:
-                losses.append(batch_loss)
-        avg_loss = cumsum_loss / _f(batch_count)
-        if return_losses:
-            return avg_loss, losses
-        return avg_loss
+                mlosses.append(batch_mloss)
+            cumsum_mloss += batch_mloss
+
+        avg_mloss = cumsum_mloss / _f(batch_count)
+        if return_losses == 'both':
+            avg_loss = cumsum_loss / _f(batch_count)
+            return avg_loss, avg_mloss, losses, mlosses
+        elif return_losses:
+            return avg_mloss, mlosses
+        return avg_mloss
+
+    def test_batched(self, *args, **kwargs):
+        return self.train_batched(*args, test_only=True, **kwargs)
 
 # Learners {{{1
 
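Note on the revised interface, as a minimal usage sketch (the ritual, train_x/train_y, test_x/test_y names and the batch size of 40 below are illustrative assumptions, not part of this commit): return_losses now accepts False, True, or 'both', and test_batched simply forwards to train_batched with test_only=True, which skips the self.en/self.bn counters, the learner.batch() call, and update().

# hypothetical usage sketch; names above are assumptions.
# inputs.shape[0] must still be evenly divisible by batch_size.
avg_mloss = ritual.train_batched(train_x, train_y, batch_size=40)

# return_losses=True also collects the per-batch measured losses.
avg_mloss, mlosses = ritual.train_batched(train_x, train_y, 40, return_losses=True)

# return_losses='both' additionally tracks the raw training loss
# (self.loss.forward) alongside the measurement metric (self.measure).
avg_loss, avg_mloss, losses, mlosses = ritual.train_batched(
    train_x, train_y, 40, return_losses='both')

# evaluation reuses the same loop with test_only=True, so the epoch/batch
# counters, learner.batch(), and update() are skipped.
avg_mloss = ritual.test_batched(test_x, test_y, 40)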