From 19a9583da977328ca9d6644bc78a7437a607b2b6 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Wed, 22 Mar 2017 18:41:50 +0000
Subject: [PATCH] add generator-based training method

---
 optim_nn_core.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/optim_nn_core.py b/optim_nn_core.py
index e6b3599..9361f20 100644
--- a/optim_nn_core.py
+++ b/optim_nn_core.py
@@ -708,6 +708,63 @@ class Ritual: # i'm just making up names at this point
         self.bn = 0
         self.model = model
 
+    def train_batched_gen(self, generator, batch_count,
+                          return_losses=False, test_only=False):
+        assert isinstance(return_losses, bool) or return_losses == 'both'
+
+        if not test_only:
+            self.en += 1
+
+        cumsum_loss, cumsum_mloss = _0, _0
+        losses, mlosses = [], []
+
+        prev_batch_size = None
+
+        for b in range(batch_count):
+            if not test_only:
+                self.bn += 1
+
+            # TODO: pass a GeneratorData object containing en, bn, ritual/model fields.
+            # ...is there a pythonic way of doing that?
+            batch_inputs, batch_outputs = next(generator)
+
+            batch_size = batch_inputs.shape[0]
+            assert batch_size == prev_batch_size or prev_batch_size is None, \
+                "non-constant batch size (got {} expected {})".format(
+                    batch_size, prev_batch_size)  # TODO: lift this restriction
+            prev_batch_size = batch_size
+
+            if not test_only and self.learner.per_batch:
+                self.learner.batch(b / batch_count)
+
+            if test_only:
+                predicted = self.model.forward(batch_inputs)
+            else:
+                predicted = self.learn(batch_inputs, batch_outputs)
+                self.update()
+
+            if return_losses == 'both':
+                batch_loss = self.loss.forward(predicted, batch_outputs)
+                if np.isnan(batch_loss):
+                    raise Exception("nan")
+                losses.append(batch_loss)
+                cumsum_loss += batch_loss
+
+            batch_mloss = self.measure(predicted, batch_outputs)
+            if np.isnan(batch_mloss):
+                raise Exception("nan")
+            if return_losses:
+                mlosses.append(batch_mloss)
+            cumsum_mloss += batch_mloss
+
+        avg_mloss = cumsum_mloss / _f(batch_count)
+        if return_losses == 'both':
+            avg_loss = cumsum_loss / _f(batch_count)
+            return avg_loss, avg_mloss, losses, mlosses
+        elif return_losses:
+            return avg_mloss, mlosses
+        return avg_mloss
+
     def train_batched(self, inputs, outputs, batch_size,
                       return_losses=False, test_only=False):
         assert isinstance(return_losses, bool) or return_losses == 'both'
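
Usage note: below is a minimal sketch of how the new train_batched_gen might
be driven. It is not part of the patch. The infinite_batches helper and the
ritual, train_inputs, and train_outputs names are hypothetical, invented here
for illustration; the only interface assumed from the patch itself is that
the generator yields (batch_inputs, batch_outputs) pairs of constant batch
size, and that return_losses=True makes the method return (avg_mloss, mlosses).

    import numpy as np

    def infinite_batches(inputs, outputs, batch_size):
        # hypothetical helper: yield (batch_inputs, batch_outputs) pairs
        # forever, reshuffling whenever the data is exhausted. the remainder
        # is dropped so every batch has the same size, since
        # train_batched_gen asserts a constant batch size.
        while True:
            indices = np.random.permutation(len(inputs))
            for start in range(0, len(indices) - batch_size + 1, batch_size):
                batch = indices[start:start + batch_size]
                yield inputs[batch], outputs[batch]

    # assuming `ritual` is a Ritual already wired up with a model, loss,
    # and learner as elsewhere in optim_nn_core.py. one call consumes
    # batch_count batches, i.e. roughly one epoch with this helper.
    gen = infinite_batches(train_inputs, train_outputs, batch_size=32)
    batch_count = len(train_inputs) // 32
    avg_mloss, mlosses = ritual.train_batched_gen(gen, batch_count,
                                                  return_losses=True)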