From eb0011cc35ca0a81e649a49e9eb573fcb33f58e1 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Thu, 2 Feb 2017 14:25:40 -0800
Subject: [PATCH] .

---
 optim_nn.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 63 insertions(+), 9 deletions(-)

diff --git a/optim_nn.py b/optim_nn.py
index 48b06d9..d897c39 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -555,6 +555,9 @@ class Ritual: # i'm just making up names at this point
             batch_inputs  = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
+            if self.learner.per_batch:
+                self.learner.batch(b / batch_count)
+
             predicted = model.forward(batch_inputs)
             residual = predicted - batch_outputs
 
@@ -574,6 +577,8 @@ class Ritual: # i'm just making up names at this point
         return avg_loss
 
 class Learner:
+    per_batch = False
+
     def __init__(self, optim, epochs=100, rate=None):
         assert isinstance(optim, Optimizer)
         self.optim = optim
@@ -608,7 +613,7 @@ class Learner:
         return self.start_rate
 
     def next(self):
-        # returns whether or not to continue training. updates.
+        # prepares the next epoch. returns whether or not to continue training.
         if self.epoch + 1 >= self.epochs:
             return False
         if self.started:
@@ -618,6 +623,13 @@ class Learner:
         self.epoch = self.epoch # poke property setter just in case
         return True
 
+    def batch(self, progress): # TODO: rename
+        # interpolates rates between epochs.
+        # unlike epochs, we do not store batch number as a state.
+        # i.e. calling next() will not respect progress.
+        assert 0 <= progress <= 1
+        self.rate = self.rate_at(self._epoch + progress)
+
     @property
     def final_rate(self):
         return self.rate_at(self.epochs - 1)
@@ -629,7 +641,7 @@ class AnnealingLearner(Learner):
         super().__init__(optim, epochs, rate)
 
     def rate_at(self, epoch):
-        return super().rate_at(epoch) * self.anneal**epoch
+        return self.start_rate * self.anneal**epoch
 
 class DumbLearner(AnnealingLearner):
     # this is my own awful contraption. it's not really "SGD with restarts".
@@ -656,6 +668,40 @@ class DumbLearner(AnnealingLearner):
                 self.restart_callback(restart)
         return True
 
+def cosmod(x):
+    # plot: https://www.desmos.com/calculator/hlgqmyswy2
+    return (1 + np.cos((x % 1) * np.pi)) / 2
+
+class SGDR(Learner):
+    # Stochastic Gradient Descent with Restarts
+    # paper: https://arxiv.org/abs/1608.03983
+    # NOTE: this is not a complete implementation.
+
+    per_batch = True
+
+    def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None):
+        self.restart_epochs = int(epochs)
+        self.decay = float(restart_decay)
+        self.restarts = int(restarts)
+        self.restart_callback = callback
+        epochs = self.restart_epochs * (self.restarts + 1)
+        super().__init__(optim, epochs, rate)
+
+    def rate_at(self, epoch):
+        sub_epoch = epoch % self.restart_epochs
+        x = sub_epoch / self.restart_epochs
+        restart = epoch // self.restart_epochs
+        return self.start_rate * self.decay**restart * cosmod(x)
+
+    def next(self):
+        if not super().next():
+            return False
+        sub_epoch = self.epoch % self.restart_epochs
+        restart = self.epoch // self.restart_epochs
+        if restart > 0 and sub_epoch == 0:
+            if self.restart_callback is not None:
+                self.restart_callback(restart)
+        return True
+
 def multiresnet(x, width, depth, block=2, multi=1,
                 activation=Relu,
                 style='batchless',
                 init=init_he_normal):
@@ -743,10 +789,11 @@ def run(program, args=[]):
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD
 
-        # learning parameters: SGD with restarts (kinda)
+        # learning parameters
+        learner = 'SGDR',
         learn = 1e-2,
         epochs = 24,
-        learn_halve_every = 16,
+        learn_halve_every = 16, # 12 might be ideal for SGDR?
         restarts = 2,
         learn_restart_advance = 16,
 
@@ -826,11 +873,18 @@ def run(program, args=[]):
 
     #
 
-    learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
-                          halve_every=config.learn_halve_every,
-                          restarts=config.restarts, restart_advance=config.learn_restart_advance,
-                          callback=rscb)
-    log("final learning rate", "{:10.8f}".format(learner.final_rate))
+    if config.learner == 'SGDR':
+        decay = 0.5**(1/(config.epochs / config.learn_halve_every))
+        learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
+                       restart_decay=decay, restarts=config.restarts,
+                       callback=rscb)
+        # final learning rate isn't of interest here; it's gonna be close to 0.
+    else:
+        learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
+                              halve_every=config.learn_halve_every,
+                              restarts=config.restarts, restart_advance=config.learn_restart_advance,
+                              callback=rscb)
+        log("final learning rate", "{:10.8f}".format(learner.final_rate))
 
     #
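
Note (illustration only, not part of the patch): the sketch below traces the learning-rate
schedule that SGDR.rate_at computes, using the same cosmod shape and the same decay that
run() derives from learn_halve_every. The helper name sgdr_rate is made up for this sketch;
it depends only on numpy and does not touch the Learner/Optimizer classes in optim_nn.py.

    import numpy as np

    def cosmod(x):
        # same shape as the patch's cosmod(): 1 at the start of each period,
        # falling along a half-cosine toward 0 just before the next restart.
        return (1 + np.cos((x % 1) * np.pi)) / 2

    def sgdr_rate(epoch, start_rate, restart_epochs, restart_decay):
        # standalone equivalent of SGDR.rate_at(): cosine-anneal within each
        # restart period, and scale the peak by restart_decay after every restart.
        restart = epoch // restart_epochs
        x = (epoch % restart_epochs) / restart_epochs
        return start_rate * restart_decay**restart * cosmod(x)

    # mirror run()'s defaults: learn=1e-2, epochs=24, learn_halve_every=16, restarts=2
    learn, epochs, learn_halve_every, restarts = 1e-2, 24, 16, 2
    decay = 0.5**(1 / (epochs / learn_halve_every))  # peak rate halves every 16 epochs

    for e in range(epochs * (restarts + 1)):
        # a per_batch learner also evaluates fractional epochs, e.g. e + b / batch_count
        print("epoch %3d  rate %.8f" % (e, sgdr_rate(e, learn, epochs, decay)))

Because per_batch is True, Ritual calls learner.batch(b / batch_count) on every minibatch,
so the rate follows this curve smoothly within an epoch instead of stepping once per next().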