Connor Olding 2017-02-27 22:48:49 +00:00
parent 547017a6fc
commit a1b1daf3bf


@@ -701,59 +701,58 @@ class Ritual: # i'm just making up names at this point
         self.bn = 0
         self.model = model
 
-    def train_batched(self, inputs, outputs, batch_size, return_losses=False):
-        self.en += 1
-        cumsum_loss = _0
+    def train_batched(self, inputs, outputs, batch_size,
+                      return_losses=False, test_only=False):
+        assert isinstance(return_losses, bool) or return_losses == 'both'
+
+        if not test_only:
+            self.en += 1
+
+        cumsum_loss, cumsum_mloss = _0, _0
         batch_count = inputs.shape[0] // batch_size
-        losses = []
+        losses, mlosses = [], []
         assert inputs.shape[0] % batch_size == 0, \
           "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
         for b in range(batch_count):
-            self.bn += 1
+            if not test_only:
+                self.bn += 1
             bi = b * batch_size
             batch_inputs  = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
-            if self.learner.per_batch:
+            if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)
 
             predicted = self.learn(batch_inputs, batch_outputs)
-            self.update()
+            if not test_only:
+                self.update()
+
+            if return_losses == 'both':
+                batch_loss = self.loss.forward(predicted, batch_outputs)
+                if np.isnan(batch_loss):
+                    raise Exception("nan")
+                losses.append(batch_loss)
+                cumsum_loss += batch_loss
 
-            batch_loss = self.measure(predicted, batch_outputs)
-            if np.isnan(batch_loss):
+            batch_mloss = self.measure(predicted, batch_outputs)
+            if np.isnan(batch_mloss):
                 raise Exception("nan")
-            cumsum_loss += batch_loss
             if return_losses:
-                losses.append(batch_loss)
+                mlosses.append(batch_mloss)
+            cumsum_mloss += batch_mloss
 
-        avg_loss = cumsum_loss / _f(batch_count)
-        if return_losses:
-            return avg_loss, losses
-        return avg_loss
-
-    def test_batched(self, inputs, outputs, batch_size, return_losses=False):
-        cumsum_loss = _0
-        batch_count = inputs.shape[0] // batch_size
-        losses = []
-        assert inputs.shape[0] % batch_size == 0, \
-          "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
-        for b in range(batch_count):
-            bi = b * batch_size
-            batch_inputs  = inputs[ bi:bi+batch_size]
-            batch_outputs = outputs[bi:bi+batch_size]
-
-            predicted = self.model.forward(batch_inputs)
-
-            batch_loss = self.measure(predicted, batch_outputs)
-            if np.isnan(batch_loss):
-                raise Exception("nan")
-            cumsum_loss += batch_loss
-            if return_losses:
-                losses.append(batch_loss)
-
-        avg_loss = cumsum_loss / _f(batch_count)
-        if return_losses:
-            return avg_loss, losses
-        return avg_loss
+        avg_mloss = cumsum_mloss / _f(batch_count)
+        if return_losses == 'both':
+            avg_loss = cumsum_loss / _f(batch_count)
+            return avg_loss, avg_mloss, losses, mlosses
+        elif return_losses:
+            return avg_mloss, mlosses
+        return avg_mloss
+
+    def test_batched(self, *args, **kwargs):
+        return self.train_batched(*args, test_only=True, **kwargs)
 
 # Learners {{{1
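
For context, a minimal usage sketch of the reworked interface. The names ritual, train_x, train_y, test_x, and test_y are hypothetical (not from the commit); per the assert above, each array's first dimension must be evenly divisible by batch_size.

# A minimal usage sketch, assuming a constructed Ritual instance `ritual`
# and numpy input/output arrays; all variable names here are hypothetical.

# training pass: advances the epoch/batch counters and applies updates.
# with return_losses='both', both the raw loss (self.loss.forward) and
# the measured loss (self.measure) come back, averaged and per-batch.
avg_loss, avg_mloss, losses, mlosses = ritual.train_batched(
    train_x, train_y, batch_size=50, return_losses='both')

# evaluation pass: test_batched now just forwards to train_batched with
# test_only=True, so self.update() and the counters are skipped.
avg_mloss, mlosses = ritual.test_batched(
    test_x, test_y, batch_size=50, return_losses=True)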