From fd2fc4bd7605f712998b321aa5eaa8d9489e166a Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Tue, 5 Feb 2019 06:09:35 +0100
Subject: [PATCH] begin rewriting Ritual

---
 onn/ritual.py      | 12 ++++++------
 onn/ritual_base.py | 90 ++++++++++++++++++++++++++++++++++++++--------
 onn/utility.py     | 26 ++++++++++++++
 3 files changed, 108 insertions(+), 20 deletions(-)

diff --git a/onn/ritual.py b/onn/ritual.py
index eab76bb..e32f483 100644
--- a/onn/ritual.py
+++ b/onn/ritual.py
@@ -38,7 +38,7 @@ class StochMRitual(Ritual):
         self.W = np.copy(model.W)
         super().prepare(model)
 
-    def learn(self, inputs, outputs):
+    def _learn(self, inputs, outputs):
         # an experiment:
         # assert self.learner.rate < 10, self.learner.rate
         # self.gamma = 1 - 1/2**(1 - np.log10(self.learner.rate))
@@ -51,7 +51,7 @@ class StochMRitual(Ritual):
         self.model.W[:] = self.W
         return residual
 
-    def update(self):
-        super().update()
+    def _update(self):
+        super()._update()
         f = 0.5
         for layer in self.model.ordered_nodes:
@@ -68,7 +68,7 @@ class NoisyRitual(Ritual):
         self.gradient_noise = _f(gradient_noise)
         super().__init__(learner)
 
-    def learn(self, inputs, outputs):
+    def _learn(self, inputs, outputs):
         # this is pretty crude
         if self.input_noise > 0:
             s = self.input_noise
@@ -78,7 +78,7 @@ class NoisyRitual(Ritual):
             outputs = outputs + np.random.normal(0, s, size=outputs.shape)
-        return super().learn(inputs, outputs)
+        return super()._learn(inputs, outputs)
 
-    def update(self):
+    def _update(self):
         # gradient noise paper: https://arxiv.org/abs/1511.06807
         if self.gradient_noise > 0:
             size = len(self.model.dW)
diff --git a/onn/ritual_base.py b/onn/ritual_base.py
index 1cbd1eb..80c6366 100644
--- a/onn/ritual_base.py
+++ b/onn/ritual_base.py
@@ -1,7 +1,11 @@
 import types
 import numpy as np
+from collections import namedtuple
 
 from .float import _f, _0
+from .utility import batchize
+
+Losses = namedtuple("Losses", ["avg_loss", "avg_mloss", "losses", "mlosses"])
 
 
 class Ritual:  # i'm just making up names at this point.
@@ -14,21 +18,49 @@ class Ritual:  # i'm just making up names at this point.
         self.en = 0
         self.bn = 0
 
-    def learn(self, inputs, outputs):
+    def prepare(self, model):
+        self.en = 0
+        self.bn = 0
+        self.model = model
+
+    def _learn(self, inputs, outputs):
         error, predicted = self.model.forward(inputs, outputs)
         self.model.backward(predicted, outputs)
         self.model.regulate()
         return error, predicted
 
-    def update(self):
+    def _update(self):
         optim = self.learner.optim
         optim.model = self.model
         optim.update(self.model.dW, self.model.W)
 
-    def prepare(self, model):
-        self.en = 0
-        self.bn = 0
-        self.model = model
+    def _measure(self, predicted, outputs):
+        loss = self.model.loss.forward(predicted, outputs)
+        if np.isnan(loss):
+            raise Exception("nan")
+        self.losses.append(loss)
+        self.cumsum_loss += loss
+
+        mloss = self.model.mloss.forward(predicted, outputs)
+        if np.isnan(mloss):
+            raise Exception("nan")
+        self.mlosses.append(mloss)
+        self.cumsum_mloss += mloss
+
+    def _train_batch_new(self, inputs, outputs, b, batch_count):
+        if self.learner.per_batch:
+            self.learner.batch(b / batch_count)
+
+        error, predicted = self.model.forward(inputs, outputs)
+        error += self.model.regulate_forward()
+        self.model.backward(predicted, outputs)
+        self.model.regulate()
+
+        optim = self.learner.optim
+        optim.model = self.model
+        optim.update(self.model.dW, self.model.W)
+
+        return predicted
 
     def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
                      test_only=False, loss_logging=False, mloss_logging=True):
@@ -38,9 +70,9 @@ class Ritual:  # i'm just making up names at this point.
         if test_only:
             predicted = self.model.evaluate(batch_inputs, deterministic=True)
         else:
-            error, predicted = self.learn(batch_inputs, batch_outputs)
+            error, predicted = self._learn(batch_inputs, batch_outputs)
             self.model.regulate_forward()
-            self.update()
+            self._update()
 
         if loss_logging:
             batch_loss = self.model.loss.forward(predicted, batch_outputs)
@@ -57,6 +89,24 @@ class Ritual:  # i'm just making up names at this point.
             self.mlosses.append(batch_mloss)
             self.cumsum_mloss += batch_mloss
 
+    def train(self, batch_gen, batch_count, clear_grad=True):
+        assert self.model is not None, "call prepare(model) before training"
+        self.en += 1
+
+        self.cumsum_loss, self.cumsum_mloss = _0, _0
+        self.losses, self.mlosses = [], []
+
+        for b, (inputs, outputs) in enumerate(batch_gen):
+            self.bn += 1
+            if clear_grad:
+                self.model.clear_grad()
+            predicted = self._train_batch_new(inputs, outputs, b, batch_count)
+            self._measure(predicted, outputs)
+
+        avg_mloss = self.cumsum_mloss / _f(batch_count)
+        avg_loss = self.cumsum_loss / _f(batch_count)
+        return Losses(avg_loss, avg_mloss, self.losses, self.mlosses)
+
     def train_batched(self, inputs_or_generator, outputs_or_batch_count,
                       batch_size=None,
                       return_losses=False, test_only=False, shuffle=True,
@@ -128,10 +178,22 @@ class Ritual:  # i'm just making up names at this point.
             return avg_mloss, self.mlosses
         return avg_mloss
 
-    def test_batched(self, inputs, outputs, *args, **kwargs):
-        return self.train_batched(inputs, outputs, *args,
-                                  test_only=True, **kwargs)
+    def test_batched(self, inputs, outputs, batch_size=None):
+        assert self.model is not None, "call prepare(model) before testing"
 
-    def train_batched_gen(self, generator, batch_count, *args, **kwargs):
-        return self.train_batched(generator, batch_count, *args,
-                                  shuffle=False, **kwargs)
+        if batch_size is None:
+            batch_size = len(inputs)
+
+        self.cumsum_loss, self.cumsum_mloss = _0, _0
+        self.losses, self.mlosses = [], []
+
+        batch_gen, batch_count = batchize(inputs, outputs, batch_size,
+                                          shuffle=False)
+
+        for inputs, outputs in batch_gen:
+            predicted = self.model.evaluate(inputs, deterministic=True)
+            self._measure(predicted, outputs)
+
+        avg_mloss = self.cumsum_mloss / _f(batch_count)
+        avg_loss = self.cumsum_loss / _f(batch_count)
+        return Losses(avg_loss, avg_mloss, self.losses, self.mlosses)
diff --git a/onn/utility.py b/onn/utility.py
index 3ef99d2..06ab2f9 100644
--- a/onn/utility.py
+++ b/onn/utility.py
@@ -37,6 +37,32 @@ def onehot(y):
     return Y
 
 
+def batchize(inputs, outputs, batch_size, shuffle=True):
+    batch_count = np.ceil(len(inputs) / batch_size).astype(int)
+
+    if shuffle:
+        def gen():
+            indices = np.arange(len(inputs))
+            np.random.shuffle(indices)
+
+            for b in range(batch_count):
+                bi = b * batch_size
+                batch_indices = indices[bi:bi + batch_size]
+                batch_inputs = inputs[batch_indices]
+                batch_outputs = outputs[batch_indices]
+                yield batch_inputs, batch_outputs
+
+    else:
+        def gen():
+            for b in range(batch_count):
+                bi = b * batch_size
+                batch_inputs = inputs[bi:bi + batch_size]
+                batch_outputs = outputs[bi:bi + batch_size]
+                yield batch_inputs, batch_outputs
+
+    return gen(), batch_count
+
+
 # more
 _log_was_update = False
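
Usage sketch (not part of the patch): one way the rewritten API might be
driven. Batching now lives in batchize(), Ritual.train() consumes a
(generator, batch_count) pair, and per-run losses come back in the Losses
namedtuple instead of being threaded through return-value flags. The
`model` and `learner` objects and the train_X/train_Y/test_X/test_Y
arrays are assumptions here; their construction isn't shown in this patch.

    from onn.ritual_base import Ritual
    from onn.utility import batchize

    ritual = Ritual(learner=learner)   # constructor signature assumed
    ritual.prepare(model)              # required before train()/test_batched()

    for epoch in range(10):
        # re-batchize each epoch: the generator is single-use and
        # reshuffles the data every time it is created.
        batch_gen, batch_count = batchize(train_X, train_Y, batch_size=32)
        res = ritual.train(batch_gen, batch_count)
        print(epoch, "avg loss:", res.avg_loss, "avg mloss:", res.avg_mloss)

    test = ritual.test_batched(test_X, test_Y, batch_size=100)
    print("test avg mloss:", test.avg_mloss)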
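
And a quick self-contained check of batchize() on its own (the shapes are
arbitrary; only the utility added above is exercised):

    import numpy as np
    from onn.utility import batchize

    X = np.arange(10).reshape(10, 1)
    Y = np.arange(10).reshape(10, 1)

    gen, count = batchize(X, Y, batch_size=4, shuffle=False)
    assert count == 3  # ceil(10 / 4)
    sizes = [len(bx) for bx, by in gen]
    assert sizes == [4, 4, 2]  # the trailing batch is short, not dropped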