add generator-based training method

This commit is contained in:
Connor Olding 2017-03-22 18:41:50 +00:00
parent 53a7d92288
commit 19a9583da9

View file

@ -708,6 +708,63 @@ class Ritual: # i'm just making up names at this point
self.bn = 0 self.bn = 0
self.model = model self.model = model
def train_batched_gen(self, generator, batch_count,
                      return_losses=False, test_only=False):
    """Run one pass of `batch_count` batches pulled from `generator`.

    Each batch is obtained with ``next(generator)`` and unpacked into an
    ``(inputs, outputs)`` pair; the batch size is read from
    ``inputs.shape[0]`` and must stay the same for the whole pass.

    When `test_only` is false, ``self.en`` is incremented once up front
    and ``self.bn`` once per batch; every batch goes through
    ``self.learn`` followed by ``self.update``, and if
    ``self.learner.per_batch`` is truthy, ``self.learner.batch`` is
    called first with the fraction ``b / batch_count``.  When `test_only`
    is true none of that happens and predictions come from
    ``self.model.forward`` instead.

    Parameters:
        generator: iterator producing ``(batch_inputs, batch_outputs)``.
        batch_count: number of batches to draw from `generator`.
        return_losses: ``False``, ``True`` or the string ``'both'``;
            selects the shape of the return value (see below).
        test_only: evaluate without learning or touching counters.

    Returns:
        * ``return_losses == 'both'``:
          ``(avg_loss, avg_mloss, losses, mlosses)`` where the "loss"
          values come from ``self.loss.forward`` and the "mloss" values
          from ``self.measure``.
        * ``return_losses is True``: ``(avg_mloss, mlosses)``.
        * otherwise: ``avg_mloss`` alone.

        Averages are the accumulated sums divided by ``_f(batch_count)``.

    Raises:
        AssertionError: `return_losses` is not a bool or ``'both'``, or a
            batch has a different size than the previous one.
        Exception: a per-batch loss or measured loss is NaN.
    """
    assert isinstance(return_losses, bool) or return_losses == 'both'

    training = not test_only
    want_both = return_losses == 'both'

    if training:
        self.en += 1

    loss_total = _0
    mloss_total = _0
    loss_history = []
    mloss_history = []
    expected_size = None

    for step in range(batch_count):
        if training:
            self.bn += 1

        # TODO: pass a GeneratorData object containing en, bn,
        #       ritual/model fields.
        #       ...is there a pythonic way of doing that?
        batch_inputs, batch_outputs = next(generator)
        batch_size = batch_inputs.shape[0]
        # TODO: lift this restriction
        assert expected_size is None or batch_size == expected_size, \
            "non-constant batch size (got {} expected {})".format(
                batch_size, expected_size)
        expected_size = batch_size

        if training:
            # The learner hook runs before the learning step for this batch.
            if self.learner.per_batch:
                self.learner.batch(step / batch_count)
            predicted = self.learn(batch_inputs, batch_outputs)
            self.update()
        else:
            predicted = self.model.forward(batch_inputs)

        if want_both:
            batch_loss = self.loss.forward(predicted, batch_outputs)
            if np.isnan(batch_loss):
                raise Exception("nan")
            loss_history.append(batch_loss)
            loss_total += batch_loss

        batch_mloss = self.measure(predicted, batch_outputs)
        if np.isnan(batch_mloss):
            raise Exception("nan")
        if return_losses:
            mloss_history.append(batch_mloss)
        mloss_total += batch_mloss

    avg_mloss = mloss_total / _f(batch_count)

    if want_both:
        avg_loss = loss_total / _f(batch_count)
        return avg_loss, avg_mloss, loss_history, mloss_history
    if return_losses:
        return avg_mloss, mloss_history
    return avg_mloss
def train_batched(self, inputs, outputs, batch_size, def train_batched(self, inputs, outputs, batch_size,
return_losses=False, test_only=False): return_losses=False, test_only=False):
assert isinstance(return_losses, bool) or return_losses == 'both' assert isinstance(return_losses, bool) or return_losses == 'both'