# optim/onn/learner.py

import numpy as np

from .float import _f, _1, _pi, _inv2
from .optimizer_base import *


class Learner:
    per_batch = False

    def __init__(self, optim, epochs=100, rate=None):
        assert isinstance(optim, Optimizer)
        self.optim = optim
        self.start_rate = rate  # None is okay; it'll use optim.lr instead.
        self.epochs = int(epochs)
        self.reset()

    def reset(self, optim=False):
        self.started = False
        self.epoch = 0
        if optim:
            self.optim.reset()

    @property
    def epoch(self):
        return self._epoch

    @epoch.setter
    def epoch(self, new_epoch):
        self._epoch = int(new_epoch)
        if 0 <= self.epoch <= self.epochs:
            self.rate = self.rate_at(self._epoch)

    @property
    def rate(self):
        return self.optim.lr

    @rate.setter
    def rate(self, new_rate):
        self.optim.lr = new_rate

    def rate_at(self, epoch):
        if self.start_rate is None:
            return self.optim.lr
        return self.start_rate

    def next(self):
        # prepares the next epoch. returns whether or not to continue training.
        if not self.started:
            self.started = True
        self.epoch += 1
        if self.epoch > self.epochs:
            return False
        return True

    def batch(self, progress):  # TODO: rename
        # interpolates rates between epochs.
        # unlike epochs, the batch number is not stored as state,
        # i.e. calling next() will not respect progress.
        assert 0 <= progress <= 1
        self.rate = self.rate_at(self._epoch + progress)

    @property
    def final_rate(self):
        return self.rate_at(self.epochs - 1e-8)
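

# A minimal sketch of the intended training loop, assuming some Optimizer
# subclass (a hypothetical `SGD` here) and a `train_batches` sequence exist
# elsewhere; neither is defined in this module:
#
#   learner = Learner(SGD(lr=0.01), epochs=10)
#   while learner.next():            # advance one epoch; updates optim.lr
#       for i, batch in enumerate(train_batches):
#           if learner.per_batch:
#               learner.batch(i / len(train_batches))  # interpolate within the epoch
#           ...  # forward/backward pass; the optimizer steps at learner.rate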


class AnnealingLearner(Learner):
    def __init__(self, optim, epochs=100, rate=None, halve_every=10):
        self.halve_every = _f(halve_every)
        self.anneal = _f(0.5**(1 / self.halve_every))
        super().__init__(optim, epochs, rate)

    def rate_at(self, epoch):
        return super().rate_at(epoch) * self.anneal**epoch
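        # note: since anneal = 0.5**(1/halve_every), this is equivalent to
        # base_rate * 0.5**(epoch / halve_every): the rate halves every
        # `halve_every` epochs.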


def cosmod(x):
    # plot: https://www.desmos.com/calculator/hlgqmyswy2
    return (_1 + np.cos((x % _1) * _pi)) * _inv2
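    # cosmod(0) == 1 and cosmod(x) -> 0 as x -> 1, snapping back to 1 at every
    # integer; SGDR below uses this for cosine annealing between restarts.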


class SGDR(Learner):
    # Stochastic Gradient Descent with Restarts
    # paper: https://arxiv.org/abs/1608.03983
    # NOTE: this is missing a couple of the proposed features.
    per_batch = True

    def __init__(self, optim, epochs=100, rate=None,
                 restarts=0, restart_decay=0.5, callback=None,
                 expando=0):
        self.restart_epochs = int(epochs)
        self.decay = _f(restart_decay)
        self.restarts = int(restarts)
        self.restart_callback = callback
        # TODO: rename expando to something not insane
        self.expando = expando if expando is not None else lambda i: i
        if type(self.expando) == int:
            inc = self.expando
            self.expando = lambda i: i * inc
        self.splits = []
        epochs = 0
        for i in range(0, self.restarts + 1):
            split = epochs + self.restart_epochs + int(self.expando(i))
            self.splits.append(split)
            epochs = split
        super().__init__(optim, epochs, rate)
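        # splits are the cumulative epoch counts at which each segment ends;
        # e.g. epochs=100, restarts=2, expando=1 gives splits [100, 201, 303]:
        # segments of 100, 101, and 102 epochs, 303 epochs total.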

    def split_num(self, epoch):
        previous = [0] + self.splits
        for i, split in enumerate(self.splits):
            if epoch - 1 < split:
                sub_epoch = epoch - previous[i]
                next_restart = split - previous[i]
                return i, sub_epoch, next_restart
        raise Exception('this should never happen.')

    def rate_at(self, epoch):
        sr = self.start_rate
        base_rate = sr if sr is not None else self.optim.lr
        restart, sub_epoch, next_restart = self.split_num(max(1, epoch))
        x = _f(sub_epoch - 1) / _f(next_restart)
        return base_rate * self.decay**_f(restart) * cosmod(x)
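        # within each segment this anneals from the segment's peak rate,
        # base_rate * decay**restart, down toward zero following cosmod.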

    def next(self):
        if not super().next():
            return False
        restart, sub_epoch, next_restart = self.split_num(self.epoch)
        if restart > 0 and sub_epoch == 1:
            if self.restart_callback is not None:
                self.restart_callback(restart)
        return True
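
# A rough usage sketch (`SGD` is a hypothetical Optimizer subclass, not
# defined in this module):
#
#   def on_restart(restart):
#       print('warm restart', restart)
#
#   learner = SGDR(SGD(lr=0.1), epochs=100, restarts=2,
#                  restart_decay=0.5, expando=0, callback=on_restart)
#   # expando=0 keeps every segment 100 epochs long; each restart resets the
#   # cosine schedule and multiplies the peak rate by restart_decay.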


class TriangularCLR(Learner):
    per_batch = True

    def __init__(self, optim, epochs=400, upper_rate=None, lower_rate=0,
                 frequency=100, callback=None):
        # NOTE: start_rate is treated as upper_rate
        self.frequency = int(frequency)
        assert self.frequency > 0
        self.callback = callback
        self.lower_rate = _f(lower_rate)
        super().__init__(optim, epochs, upper_rate)

    def _t(self, epoch):
        # NOTE: this could probably be simplified
        offset = self.frequency / 2
        return np.abs(((epoch - 1 + offset) % self.frequency) - offset) \
            / offset
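        # this is a triangle wave with period `frequency`: 0 at the first
        # epoch of each cycle, 1 at the midpoint, and back to 0.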

    def rate_at(self, epoch):
        sr = self.start_rate
        lr = self.lower_rate
        upper_rate = sr if sr is not None else self.optim.lr
        return self._t(epoch) * (upper_rate - lr) + lr

    def next(self):
        if not super().next():
            return False
        e = self.epoch - 1
        if e > 0 and e % self.frequency == 0:
            if self.callback is not None:
                self.callback(self.epoch // self.frequency)
        return True
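        # the callback fires on the first epoch of each new cycle and receives
        # the number of cycles completed so far.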


class SineCLR(TriangularCLR):
    def _t(self, epoch):
        return np.sin(_pi * _inv2 * super()._t(epoch))
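        # quarter-sine easing of the triangle wave: still 0 at the cycle start
        # and 1 at the midpoint, but it flattens out near the upper rate.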


class WaveCLR(TriangularCLR):
    def _t(self, epoch):
        return _inv2 * (_1 - np.cos(_pi * super()._t(epoch)))
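        # half-cosine easing: smooth (zero slope) at both the lower and upper
        # rate, like a smoothstep between them.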


# more
class PolyLearner(Learner):
    per_batch = True

    def __init__(self, optim, epochs=400, coeffs=(1,)):
        self.coeffs = tuple(coeffs)
        super().__init__(optim, epochs, np.polyval(self.coeffs, 0))

    def rate_at(self, epoch):
        progress = (epoch - 1) / self.epochs
        ret = np.polyval(self.coeffs, progress)
        return np.abs(ret)
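        # the rate follows |polyval(coeffs, progress)| as progress runs from 0
        # to roughly 1 over training; e.g. coeffs=(-0.01, 0.01) decays the rate
        # linearly from 0.01 toward 0. the abs() keeps the rate non-negative.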