diff --git a/optim_nn_core.py b/optim_nn_core.py index f65b7d4..9f8f6ff 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -701,59 +701,58 @@ class Ritual: # i'm just making up names at this point self.bn = 0 self.model = model - def train_batched(self, inputs, outputs, batch_size, return_losses=False): - self.en += 1 - cumsum_loss = _0 + def train_batched(self, inputs, outputs, batch_size, + return_losses=False, test_only=False): + assert isinstance(return_losses, bool) or return_losses == 'both' + + if not test_only: + self.en += 1 + + cumsum_loss, cumsum_mloss = _0, _0 batch_count = inputs.shape[0] // batch_size - losses = [] + losses, mlosses = [], [] + assert inputs.shape[0] % batch_size == 0, \ "inputs is not evenly divisible by batch_size" # TODO: lift this restriction for b in range(batch_count): - self.bn += 1 + if not test_only: + self.bn += 1 + bi = b * batch_size batch_inputs = inputs[ bi:bi+batch_size] batch_outputs = outputs[bi:bi+batch_size] - if self.learner.per_batch: - self.learner.batch(b / batch_count) + if not test_only and self.learner.per_batch: + self.learner.batch(b / batch_count) predicted = self.learn(batch_inputs, batch_outputs) - self.update() + if not test_only: + self.update() - batch_loss = self.measure(predicted, batch_outputs) - if np.isnan(batch_loss): - raise Exception("nan") - cumsum_loss += batch_loss - if return_losses: + if return_losses == 'both': + batch_loss = self.loss.forward(predicted, batch_outputs) + if np.isnan(batch_loss): + raise Exception("nan") losses.append(batch_loss) - avg_loss = cumsum_loss / _f(batch_count) - if return_losses: - return avg_loss, losses - return avg_loss + cumsum_loss += batch_loss - def test_batched(self, inputs, outputs, batch_size, return_losses=False): - cumsum_loss = _0 - batch_count = inputs.shape[0] // batch_size - losses = [] - assert inputs.shape[0] % batch_size == 0, \ - "inputs is not evenly divisible by batch_size" # TODO: lift this restriction - for b in range(batch_count): - bi = b * batch_size - batch_inputs = inputs[ bi:bi+batch_size] - batch_outputs = outputs[bi:bi+batch_size] - - predicted = self.model.forward(batch_inputs) - - batch_loss = self.measure(predicted, batch_outputs) - if np.isnan(batch_loss): + batch_mloss = self.measure(predicted, batch_outputs) + if np.isnan(batch_mloss): raise Exception("nan") - cumsum_loss += batch_loss if return_losses: - losses.append(batch_loss) - avg_loss = cumsum_loss / _f(batch_count) - if return_losses: - return avg_loss, losses - return avg_loss + mlosses.append(batch_mloss) + cumsum_mloss += batch_mloss + + avg_mloss = cumsum_mloss / _f(batch_count) + if return_losses == 'both': + avg_loss = cumsum_loss / _f(batch_count) + return avg_loss, avg_mloss, losses, mlosses + elif return_losses: + return avg_mloss, mlosses + return avg_mloss + + def test_batched(self, *args, **kwargs): + return self.train_batched(*args, test_only=True, **kwargs) # Learners {{{1