This commit is contained in:
Connor Olding 2017-02-27 22:48:49 +00:00
parent 547017a6fc
commit a1b1daf3bf

View file

@ -701,59 +701,58 @@ class Ritual: # i'm just making up names at this point
self.bn = 0
self.model = model
def train_batched(self, inputs, outputs, batch_size, return_losses=False):
self.en += 1
cumsum_loss = _0
def train_batched(self, inputs, outputs, batch_size,
return_losses=False, test_only=False):
assert isinstance(return_losses, bool) or return_losses == 'both'
if not test_only:
self.en += 1
cumsum_loss, cumsum_mloss = _0, _0
batch_count = inputs.shape[0] // batch_size
losses = []
losses, mlosses = [], []
assert inputs.shape[0] % batch_size == 0, \
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
for b in range(batch_count):
self.bn += 1
if not test_only:
self.bn += 1
bi = b * batch_size
batch_inputs = inputs[ bi:bi+batch_size]
batch_outputs = outputs[bi:bi+batch_size]
if self.learner.per_batch:
self.learner.batch(b / batch_count)
if not test_only and self.learner.per_batch:
self.learner.batch(b / batch_count)
predicted = self.learn(batch_inputs, batch_outputs)
self.update()
if not test_only:
self.update()
batch_loss = self.measure(predicted, batch_outputs)
if np.isnan(batch_loss):
raise Exception("nan")
cumsum_loss += batch_loss
if return_losses:
if return_losses == 'both':
batch_loss = self.loss.forward(predicted, batch_outputs)
if np.isnan(batch_loss):
raise Exception("nan")
losses.append(batch_loss)
avg_loss = cumsum_loss / _f(batch_count)
if return_losses:
return avg_loss, losses
return avg_loss
cumsum_loss += batch_loss
def test_batched(self, inputs, outputs, batch_size, return_losses=False):
cumsum_loss = _0
batch_count = inputs.shape[0] // batch_size
losses = []
assert inputs.shape[0] % batch_size == 0, \
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
for b in range(batch_count):
bi = b * batch_size
batch_inputs = inputs[ bi:bi+batch_size]
batch_outputs = outputs[bi:bi+batch_size]
predicted = self.model.forward(batch_inputs)
batch_loss = self.measure(predicted, batch_outputs)
if np.isnan(batch_loss):
batch_mloss = self.measure(predicted, batch_outputs)
if np.isnan(batch_mloss):
raise Exception("nan")
cumsum_loss += batch_loss
if return_losses:
losses.append(batch_loss)
avg_loss = cumsum_loss / _f(batch_count)
if return_losses:
return avg_loss, losses
return avg_loss
mlosses.append(batch_mloss)
cumsum_mloss += batch_mloss
avg_mloss = cumsum_mloss / _f(batch_count)
if return_losses == 'both':
avg_loss = cumsum_loss / _f(batch_count)
return avg_loss, avg_mloss, losses, mlosses
elif return_losses:
return avg_mloss, mlosses
return avg_mloss
def test_batched(self, *args, **kwargs):
return self.train_batched(*args, test_only=True, **kwargs)
# Learners {{{1