import types

from collections import namedtuple

import numpy as np

from .float import _f, _0
from .utility import batchize
# assumption: Learner and Optimizer are referenced below but were never
# imported in this snippet; these paths are guesses -- point them at
# wherever the classes actually live in this package.
from .learner import Learner
from .optimizer import Optimizer
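

# results of a training or testing run. "mloss" appears to be a secondary
# "measured" loss, i.e. a metric reported alongside the loss being optimized.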
Losses = namedtuple("Losses", ["avg_loss", "avg_mloss", "losses", "mlosses"])


class Ritual: # i'm just making up names at this point.
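    # drives training and evaluation of a model: holds a Learner (which wraps
    # an Optimizer), counts epochs (en) and batches (bn), and accumulates
    # per-batch losses between calls.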

    def __init__(self, learner=None, model=None):
        self.learner = (learner if learner is not None
                        else Learner(Optimizer()))
        self.model = None
        if model is not None:
            self.prepare(model)

    def reset(self):
        self.learner.reset(optim=True)
        self.en = 0
        self.bn = 0

    def prepare(self, model):
        self.en = 0
        self.bn = 0
        self.model = model

    def _learn(self, inputs, outputs):
        # one forward/backward pass; regulate() applies weight regularization.
        error, predicted = self.model.forward(inputs, outputs)
        self.model.backward(predicted, outputs)
        self.model.regulate()
        return error, predicted

    def _update(self):
        # let the optimizer step the weights using the accumulated gradients.
        optim = self.learner.optim
        optim.model = self.model
        optim.update(self.model.dW, self.model.W)

    def _measure(self, predicted, outputs):
        loss = self.model.loss.forward(predicted, outputs)
        if np.isnan(loss):
            raise Exception("loss became NaN")
        self.losses.append(loss)
        self.cumsum_loss += loss

        mloss = self.model.mloss.forward(predicted, outputs)
        if np.isnan(mloss):
            raise Exception("mloss became NaN")
        self.mlosses.append(mloss)
        self.cumsum_mloss += mloss

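    # the newer per-batch step used by train(): advance the learner's
    # per-batch schedule (if enabled), run the forward pass plus
    # regularization, backprop, then apply the optimizer update.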
    def _train_batch_new(self, inputs, outputs, b, batch_count):
        if self.learner.per_batch:
            self.learner.batch(b / batch_count)

        error, predicted = self.model.forward(inputs, outputs)
        error += self.model.regulate_forward()
        self.model.backward(predicted, outputs)
        self.model.regulate()

        optim = self.learner.optim
        optim.model = self.model
        optim.update(self.model.dW, self.model.W)

        return predicted

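    # the older per-batch step used by train_batched(); unlike
    # _train_batch_new(), it discards the return value of regulate_forward()
    # and does its own loss logging instead of deferring to _measure().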
    def _train_batch(self, batch_inputs, batch_outputs, b, batch_count,
                     test_only=False, loss_logging=False, mloss_logging=True):
        if not test_only and self.learner.per_batch:
            self.learner.batch(b / batch_count)

        if test_only:
            predicted = self.model.evaluate(batch_inputs, deterministic=True)
        else:
            error, predicted = self._learn(batch_inputs, batch_outputs)
            self.model.regulate_forward()
            self._update()

        if loss_logging:
            batch_loss = self.model.loss.forward(predicted, batch_outputs)
            if np.isnan(batch_loss):
                raise Exception("loss became NaN")
            self.losses.append(batch_loss)
            self.cumsum_loss += batch_loss

        if mloss_logging:
            # NOTE: this can use the non-deterministic predictions. fixme?
            batch_mloss = self.model.mloss.forward(predicted, batch_outputs)
            if np.isnan(batch_mloss):
                raise Exception("mloss became NaN")
            self.mlosses.append(batch_mloss)
            self.cumsum_mloss += batch_mloss

    def train(self, batch_gen, batch_count, clear_grad=True):
        assert self.model is not None, "call prepare(model) before training"
        self.en += 1

        self.cumsum_loss, self.cumsum_mloss = _0, _0
        self.losses, self.mlosses = [], []

        for b, (inputs, outputs) in enumerate(batch_gen):
            self.bn += 1
            if clear_grad:
                self.model.clear_grad()
            predicted = self._train_batch_new(inputs, outputs, b, batch_count)
            self._measure(predicted, outputs)

        avg_mloss = self.cumsum_mloss / _f(batch_count)
        avg_loss = self.cumsum_loss / _f(batch_count)
        return Losses(avg_loss, avg_mloss, self.losses, self.mlosses)

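    # a typical call might look like this (train_x/train_y are hypothetical;
    # batchize() is used the same way in test_batched() below):
    #     batch_gen, batch_count = batchize(train_x, train_y, 32)
    #     result = ritual.train(batch_gen, batch_count)
    #     print(result.avg_loss, result.avg_mloss)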
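    # overloaded signature: pass either (inputs, outputs) arrays with a
    # batch_size, or a generator of batches together with its batch count.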
    def train_batched(self, inputs_or_generator, outputs_or_batch_count,
                      batch_size=None,
                      return_losses=False, test_only=False, shuffle=True,
                      clear_grad=True):
        assert isinstance(return_losses, bool) or return_losses == 'both'
        assert self.model is not None, "call prepare(model) before training"

        gen = isinstance(inputs_or_generator, types.GeneratorType)
        if gen:
            generator = inputs_or_generator
            batch_count = outputs_or_batch_count
            assert isinstance(batch_count, int), type(batch_count)
        else:
            inputs = inputs_or_generator
            outputs = outputs_or_batch_count

        if not test_only:
            self.en += 1

        if shuffle:
            if gen:
                raise Exception(
                    "shuffling is incompatible with using a generator.")
            indices = np.arange(inputs.shape[0])
            np.random.shuffle(indices)
            inputs = inputs[indices]
            outputs = outputs[indices]

        self.cumsum_loss, self.cumsum_mloss = _0, _0
        self.losses, self.mlosses = [], []

        if not gen:
            batch_count = inputs.shape[0] // batch_size
            # TODO: lift this restriction
            assert inputs.shape[0] % batch_size == 0, \
                "inputs is not evenly divisible by batch_size"

        prev_batch_size = None
        for b in range(batch_count):
            if not test_only:
                self.bn += 1

            if gen:
                batch_inputs, batch_outputs = next(generator)
                batch_size = batch_inputs.shape[0]
                # TODO: lift this restriction
                fmt = "non-constant batch size (got {}, expected {})"
                assert (batch_size == prev_batch_size
                        or prev_batch_size is None), \
                    fmt.format(batch_size, prev_batch_size)
            else:
                bi = b * batch_size
                batch_inputs = inputs[bi:bi+batch_size]
                batch_outputs = outputs[bi:bi+batch_size]

            if clear_grad:
                self.model.clear_grad()
            self._train_batch(batch_inputs, batch_outputs, b, batch_count,
                              test_only, return_losses == 'both',
                              return_losses)

            prev_batch_size = batch_size

        avg_mloss = self.cumsum_mloss / _f(batch_count)
        if return_losses == 'both':
            avg_loss = self.cumsum_loss / _f(batch_count)
            return avg_loss, avg_mloss, self.losses, self.mlosses
        elif return_losses:
            return avg_mloss, self.mlosses
        return avg_mloss

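    # note: unlike the test_only path of _train_batch(), this evaluates
    # without deterministic=True; that may or may not be intentional.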
    def test_batched(self, inputs, outputs, batch_size=None):
        assert self.model is not None, "call prepare(model) before testing"

        if batch_size is None:
            batch_size = len(inputs)

        self.cumsum_loss, self.cumsum_mloss = _0, _0
        self.losses, self.mlosses = [], []

        batch_gen, batch_count = batchize(inputs, outputs, batch_size,
                                          shuffle=False)

        for inputs, outputs in batch_gen:
            predicted = self.model.evaluate(inputs)
            self._measure(predicted, outputs)

        avg_mloss = self.cumsum_mloss / _f(batch_count)
        avg_loss = self.cumsum_loss / _f(batch_count)
        return Losses(avg_loss, avg_mloss, self.losses, self.mlosses)
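

# a minimal end-to-end sketch; `model`, `train_x`, etc. are hypothetical and
# stand in for whatever this package provides, and the default
# Learner(Optimizer()) is assumed:
#     ritual = Ritual(model=model)
#     for epoch in range(epochs):
#         avg_mloss = ritual.train_batched(train_x, train_y, batch_size=32)
#         test_losses = ritual.test_batched(test_x, test_y)
#         print(epoch, avg_mloss, test_losses.avg_mloss)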