begin rewriting Ritual

This commit is contained in:
Connor Olding 2019-02-05 06:09:35 +01:00
parent 2c921d34c2
commit fd2fc4bd76
3 changed files with 106 additions and 18 deletions

View File

@ -38,7 +38,7 @@ class StochMRitual(Ritual):
self.W = np.copy(model.W)
super().prepare(model)
def learn(self, inputs, outputs):
def _learn(self, inputs, outputs):
# an experiment:
# assert self.learner.rate < 10, self.learner.rate
# self.gamma = 1 - 1/2**(1 - np.log10(self.learner.rate))
@ -51,7 +51,7 @@ class StochMRitual(Ritual):
self.model.W[:] = self.W
return residual
def update(self):
def _update(self):
super().update()
f = 0.5
for layer in self.model.ordered_nodes:
@ -68,7 +68,7 @@ class NoisyRitual(Ritual):
self.gradient_noise = _f(gradient_noise)
super().__init__(learner)
def learn(self, inputs, outputs):
def _learn(self, inputs, outputs):
# this is pretty crude
if self.input_noise > 0:
s = self.input_noise
@ -78,7 +78,7 @@ class NoisyRitual(Ritual):
outputs = outputs + np.random.normal(0, s, size=outputs.shape)
return super().learn(inputs, outputs)
def update(self):
def _update(self):
# gradient noise paper: https://arxiv.org/abs/1511.06807
if self.gradient_noise > 0:
size = len(self.model.dW)

View File

@ -1,7 +1,11 @@
import types
import numpy as np
from collections import namedtuple
from .float import _f, _0
from .utility import batchize
Losses = namedtuple("Losses", ["avg_loss", "avg_mloss", "losses", "mlosses"])
class Ritual: # i'm just making up names at this point.
@ -14,21 +18,49 @@ class Ritual: # i'm just making up names at this point.
self.en = 0
self.bn = 0
def learn(self, inputs, outputs):
def prepare(self, model):
self.en = 0
self.bn = 0
self.model = model
def _learn(self, inputs, outputs):
error, predicted = self.model.forward(inputs, outputs)
self.model.backward(predicted, outputs)
self.model.regulate()
return error, predicted
def update(self):
def _update(self):
optim = self.learner.optim
optim.model = self.model
optim.update(self.model.dW, self.model.W)
def prepare(self, model):
self.en = 0
self.bn = 0
self.model = model
def _measure(self, predicted, outputs):
loss = self.model.loss.forward(predicted, outputs)
if np.isnan(loss):
raise Exception("nan")
self.losses.append(loss)
self.cumsum_loss += loss
mloss = self.model.mloss.forward(predicted, outputs)
if np.isnan(mloss):
raise Exception("nan")
self.mlosses.append(mloss)
self.cumsum_mloss += mloss
def _train_batch_new(self, inputs, outputs, b, batch_count):
if self.learner.per_batch:
self.learner.batch(b / batch_count)
error, predicted = self.model.forward(inputs, outputs)
error += self.model.regulate_forward()
self.model.backward(predicted, outputs)
self.model.regulate()
optim = self.learner.optim
optim.model = self.model
optim.update(self.model.dW, self.model.W)
return predicted
def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
test_only=False, loss_logging=False, mloss_logging=True):
@ -38,9 +70,9 @@ class Ritual: # i'm just making up names at this point.
if test_only:
predicted = self.model.evaluate(batch_inputs, deterministic=True)
else:
error, predicted = self.learn(batch_inputs, batch_outputs)
error, predicted = self._learn(batch_inputs, batch_outputs)
self.model.regulate_forward()
self.update()
self._update()
if loss_logging:
batch_loss = self.model.loss.forward(predicted, batch_outputs)
@ -57,6 +89,24 @@ class Ritual: # i'm just making up names at this point.
self.mlosses.append(batch_mloss)
self.cumsum_mloss += batch_mloss
def train(self, batch_gen, batch_count, clear_grad=True):
assert self.model is not None, "call prepare(model) before training"
self.en += 1
self.cumsum_loss, self.cumsum_mloss = _0, _0
self.losses, self.mlosses = [], []
for b, (inputs, outputs) in enumerate(batch_gen):
self.bn += 1
if clear_grad:
self.model.clear_grad()
predicted = self._train_batch_new(inputs, outputs, b, batch_count)
self._measure(predicted, outputs)
avg_mloss = self.cumsum_mloss / _f(batch_count)
avg_loss = self.cumsum_loss / _f(batch_count)
return Losses(avg_loss, avg_mloss, self.losses, self.mlosses)
def train_batched(self, inputs_or_generator, outputs_or_batch_count,
batch_size=None,
return_losses=False, test_only=False, shuffle=True,
@ -128,10 +178,22 @@ class Ritual: # i'm just making up names at this point.
return avg_mloss, self.mlosses
return avg_mloss
def test_batched(self, inputs, outputs, *args, **kwargs):
return self.train_batched(inputs, outputs, *args,
test_only=True, **kwargs)
def test_batched(self, inputs, outputs, batch_size=None):
assert self.model is not None, "call prepare(model) before testing"
def train_batched_gen(self, generator, batch_count, *args, **kwargs):
return self.train_batched(generator, batch_count, *args,
shuffle=False, **kwargs)
if batch_size is None:
batch_size = len(inputs)
self.cumsum_loss, self.cumsum_mloss = _0, _0
self.losses, self.mlosses = [], []
batch_gen, batch_count = batchize(inputs, outputs, batch_size,
shuffle=False)
for inputs, outputs in batch_gen:
predicted = self.model.evaluate(inputs)
self._measure(predicted, outputs)
avg_mloss = self.cumsum_mloss / _f(batch_count)
avg_loss = self.cumsum_loss / _f(batch_count)
return Losses(avg_loss, avg_mloss, self.losses, self.mlosses)

View File

@ -37,6 +37,32 @@ def onehot(y):
return Y
def batchize(inputs, outputs, batch_size, shuffle=True):
batch_count = np.ceil(len(inputs) / batch_size).astype(int)
if shuffle:
def gen():
indices = np.arange(len(inputs))
np.random.shuffle(indices)
for b in range(batch_count):
bi = b * batch_size
batch_indices = indices[bi:bi + batch_size]
batch_inputs = inputs[batch_indices]
batch_outputs = outputs[batch_indices]
yield batch_inputs, batch_outputs
else:
def gen():
for b in range(batch_count):
bi = b * batch_size
batch_inputs = inputs[bi:bi + batch_size]
batch_outputs = outputs[bi:bi + batch_size]
yield batch_inputs, batch_outputs
return gen(), batch_count
# more
_log_was_update = False