Connor Olding 2017-02-02 14:25:40 -08:00
parent 3fd6ef9688
commit eb0011cc35


@@ -555,6 +555,9 @@ class Ritual: # i'm just making up names at this point
             batch_inputs = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
+            if self.learner.per_batch:
+                self.learner.batch(b / batch_count)
+
             predicted = model.forward(batch_inputs)
             residual = predicted - batch_outputs
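For readers skimming the diff: a toy, self-contained sketch (the name _ToyLearner is made up here, not from the repo) of what the new per_batch hook amounts to. A learner that opts in gets called once per batch with the fraction of the epoch completed so far.

    # toy illustration, not repo code: how a per_batch learner sees progress values
    class _ToyLearner:
        per_batch = True
        def batch(self, progress):
            print("batch progress: {:.2f}".format(progress))

    learner = _ToyLearner()
    batch_count = 4
    for b in range(batch_count):
        if learner.per_batch:
            learner.batch(b / batch_count)   # prints 0.00, 0.25, 0.50, 0.75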
@@ -574,6 +577,8 @@ class Ritual: # i'm just making up names at this point
         return avg_loss
 
 class Learner:
+    per_batch = False
+
     def __init__(self, optim, epochs=100, rate=None):
         assert isinstance(optim, Optimizer)
         self.optim = optim
@@ -608,7 +613,7 @@ class Learner:
         return self.start_rate
 
     def next(self):
-        # returns whether or not to continue training. updates.
+        # prepares the next epoch. returns whether or not to continue training.
         if self.epoch + 1 >= self.epochs:
             return False
         if self.started:
@@ -618,6 +623,13 @@ class Learner:
         self.epoch = self.epoch # poke property setter just in case
         return True
 
+    def batch(self, progress): # TODO: rename
+        # interpolates rates between epochs.
+        # unlike epochs, we do not store batch number as a state.
+        # i.e. calling next() will not respect progress.
+        assert 0 <= progress <= 1
+        self.rate = self.rate_at(self._epoch + progress)
+
     @property
     def final_rate(self):
         return self.rate_at(self.epochs - 1)
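A rough, standalone sketch (not repo code) of what batch() buys a learner: the rate gets evaluated at fractional epochs, so a schedule can change smoothly within an epoch rather than only at epoch boundaries. The decay curve below is invented just for the demo; only the batch() body mirrors the hunk above.

    # _ToyDecay is hypothetical; it stands in for any learner with a smooth rate_at()
    class _ToyDecay:
        start_rate = 0.1
        _epoch = 3
        def rate_at(self, epoch):
            return self.start_rate * 0.9**epoch     # made-up smooth decay
        def batch(self, progress):
            assert 0 <= progress <= 1
            self.rate = self.rate_at(self._epoch + progress)

    t = _ToyDecay()
    t.batch(0.0); print(t.rate)   # rate exactly at epoch 3
    t.batch(0.5); print(t.rate)   # halfway through epoch 3, between the epoch-3 and epoch-4 rates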
@@ -629,7 +641,7 @@ class AnnealingLearner(Learner):
         super().__init__(optim, epochs, rate)
 
     def rate_at(self, epoch):
-        return super().rate_at(epoch) * self.anneal**epoch
+        return self.start_rate * self.anneal**epoch
 
 class DumbLearner(AnnealingLearner):
     # this is my own awful contraption. it's not really "SGD with restarts".
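One way to read the new rate_at: if anneal is derived from a halve_every setting as 0.5**(1/halve_every) (an assumption here, suggested by how the SGDR decay is computed near the end of this diff), then the rate halves every halve_every epochs.

    # assumption: anneal = 0.5**(1/halve_every); that derivation is not shown in this hunk
    start_rate, halve_every = 1e-2, 16
    anneal = 0.5**(1/halve_every)
    for epoch in (0, 16, 32):
        print(epoch, start_rate * anneal**epoch)   # ~1e-2, ~5e-3, ~2.5e-3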
@@ -656,6 +668,40 @@ class DumbLearner(AnnealingLearner):
                 self.restart_callback(restart)
         return True
 
+def cosmod(x):
+    # plot: https://www.desmos.com/calculator/hlgqmyswy2
+    return (1 + np.cos((x % 1) * np.pi)) / 2
+
+class SGDR(Learner):
+    # Stochastic Gradient Descent with Restarts
+    # paper: https://arxiv.org/abs/1608.03983
+    # NOTE: this is not a complete implementation.
+    per_batch = True
+
+    def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None):
+        self.restart_epochs = int(epochs)
+        self.decay = float(restart_decay)
+        self.restarts = int(restarts)
+        self.restart_callback = callback
+        epochs = self.restart_epochs * (self.restarts + 1)
+        super().__init__(optim, epochs, rate)
+
+    def rate_at(self, epoch):
+        sub_epoch = epoch % self.restart_epochs
+        x = sub_epoch / self.restart_epochs
+        restart = epoch // self.restart_epochs
+        return self.start_rate * self.decay**restart * cosmod(x)
+
+    def next(self):
+        if not super().next():
+            return False
+        sub_epoch = self.epoch % self.restart_epochs
+        restart = self.epoch // self.restart_epochs
+        if restart > 0 and sub_epoch == 0:
+            if self.restart_callback is not None:
+                self.restart_callback(restart)
+        return True
+
 def multiresnet(x, width, depth, block=2, multi=1,
                 activation=Relu, style='batchless',
                 init=init_he_normal):
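To make the shape of the new schedule concrete, here is a standalone sketch that just re-applies the formulas from SGDR.rate_at above: cosine annealing from the peak rate toward zero within each restart period, with the peak scaled by decay**restart after each restart. The defaults below (start_rate, restart_epochs, decay) are illustrative, not the repo's.

    import numpy as np

    def cosmod(x):
        return (1 + np.cos((x % 1) * np.pi)) / 2

    def sgdr_rate(epoch, start_rate=1e-2, restart_epochs=24, decay=0.5):
        # mirrors SGDR.rate_at, but free of the class so it can be printed/plotted alone
        sub_epoch = epoch % restart_epochs
        restart = epoch // restart_epochs
        return start_rate * decay**restart * cosmod(sub_epoch / restart_epochs)

    for e in (0, 12, 23, 24, 36, 48):
        print(e, sgdr_rate(e))
    # 0 -> 1e-2 (first peak), 12 -> 5e-3 (halfway down), 23 -> ~4e-5 (nearly zero),
    # 24 -> 5e-3 (restart: peak * decay), 36 -> 2.5e-3, 48 -> 2.5e-3 (peak * decay**2)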
@@ -743,10 +789,11 @@ def run(program, args=[]):
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD
 
-        # learning parameters: SGD with restarts (kinda)
+        # learning parameters
+        learner = 'SGDR',
         learn = 1e-2,
         epochs = 24,
-        learn_halve_every = 16,
+        learn_halve_every = 16, # 12 might be ideal for SGDR?
         restarts = 2,
         learn_restart_advance = 16,
@@ -826,11 +873,18 @@ def run(program, args=[]):
     #
 
-    learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
-                          halve_every=config.learn_halve_every,
-                          restarts=config.restarts, restart_advance=config.learn_restart_advance,
-                          callback=rscb)
-    log("final learning rate", "{:10.8f}".format(learner.final_rate))
+    if config.learner == 'SGDR':
+        decay = 0.5**(1/(config.epochs / config.learn_halve_every))
+        learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
+                       restart_decay=decay, restarts=config.restarts,
+                       callback=rscb)
+        # final learning rate isn't of interest here; it's gonna be close to 0.
+    else:
+        learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
+                              halve_every=config.learn_halve_every,
+                              restarts=config.restarts, restart_advance=config.learn_restart_advance,
+                              callback=rscb)
+        log("final learning rate", "{:10.8f}".format(learner.final_rate))
 
     #
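With the config values from the earlier hunk (epochs = 24, learn_halve_every = 16), the decay expression works out as shown below; note that for SGDR this restart_decay scales the peak rate at each restart rather than annealing continuously within an epoch.

    epochs, learn_halve_every = 24, 16   # values from the config hunk above
    decay = 0.5**(1/(epochs / learn_halve_every))
    print(decay)   # 0.5**(2/3) ≈ 0.63, i.e. each restart's peak is ~63% of the previous one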