This commit is contained in:
parent 3fd6ef9688
commit eb0011cc35

1 changed file with 63 additions and 9 deletions

optim_nn.py (72 changed lines: +63 −9)
@@ -555,6 +555,9 @@ class Ritual: # i'm just making up names at this point
             batch_inputs = inputs[bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
+            if self.learner.per_batch:
+                self.learner.batch(b / batch_count)
+
             predicted = model.forward(batch_inputs)
             residual = predicted - batch_outputs
 
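Not part of the diff, just orientation: a minimal sketch of the loop this hunk lives in. The relationship between b, bi, and batch_count is assumed here rather than taken from the commit.

    for b in range(batch_count):
        bi = b * batch_size                 # assumed; only bi appears in the hunk above
        if learner.per_batch:
            learner.batch(b / batch_count)  # progress runs from 0 toward 1 across the epoch
        # ... slice the batch, forward pass, compute the residual, apply the optimizer ...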
@@ -574,6 +577,8 @@ class Ritual: # i'm just making up names at this point
         return avg_loss
 
 class Learner:
+    per_batch = False
+
     def __init__(self, optim, epochs=100, rate=None):
         assert isinstance(optim, Optimizer)
         self.optim = optim
@@ -608,7 +613,7 @@ class Learner:
         return self.start_rate
 
     def next(self):
-        # returns whether or not to continue training. updates.
+        # prepares the next epoch. returns whether or not to continue training.
         if self.epoch + 1 >= self.epochs:
             return False
         if self.started:
@@ -618,6 +623,13 @@ class Learner:
         self.epoch = self.epoch # poke property setter just in case
         return True
 
+    def batch(self, progress): # TODO: rename
+        # interpolates rates between epochs.
+        # unlike epochs, we do not store batch number as a state.
+        # i.e. calling next() will not respect progress.
+        assert 0 <= progress <= 1
+        self.rate = self.rate_at(self._epoch + progress)
+
     @property
     def final_rate(self):
         return self.rate_at(self.epochs - 1)
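In effect, per-batch scheduling just re-evaluates rate_at at fractional epochs. A rough illustration (learner state and the epoch value are assumed, not taken from the commit):

    learner.batch(0.0)   # rate == learner.rate_at(epoch)        start of the epoch
    learner.batch(0.5)   # rate == learner.rate_at(epoch + 0.5)  halfway through the epoch
    learner.batch(1.0)   # allowed by the assert, but next() still advances by whole epochs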
@@ -629,7 +641,7 @@ class AnnealingLearner(Learner):
         super().__init__(optim, epochs, rate)
 
     def rate_at(self, epoch):
-        return super().rate_at(epoch) * self.anneal**epoch
+        return self.start_rate * self.anneal**epoch
 
 class DumbLearner(AnnealingLearner):
     # this is my own awful contraption. it's not really "SGD with restarts".
@@ -656,6 +668,40 @@ class DumbLearner(AnnealingLearner):
             self.restart_callback(restart)
         return True
 
+def cosmod(x):
+    # plot: https://www.desmos.com/calculator/hlgqmyswy2
+    return (1 + np.cos((x % 1) * np.pi)) / 2
+
+class SGDR(Learner):
+    # Stochastic Gradient Descent with Restarts
+    # paper: https://arxiv.org/abs/1608.03983
+    # NOTE: this is not a complete implementation.
+    per_batch = True
+
+    def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None):
+        self.restart_epochs = int(epochs)
+        self.decay = float(restart_decay)
+        self.restarts = int(restarts)
+        self.restart_callback = callback
+        epochs = self.restart_epochs * (self.restarts + 1)
+        super().__init__(optim, epochs, rate)
+
+    def rate_at(self, epoch):
+        sub_epoch = epoch % self.restart_epochs
+        x = sub_epoch / self.restart_epochs
+        restart = epoch // self.restart_epochs
+        return self.start_rate * self.decay**restart * cosmod(x)
+
+    def next(self):
+        if not super().next():
+            return False
+        sub_epoch = self.epoch % self.restart_epochs
+        restart = self.epoch // self.restart_epochs
+        if restart > 0 and sub_epoch == 0:
+            if self.restart_callback is not None:
+                self.restart_callback(restart)
+        return True
+
 def multiresnet(x, width, depth, block=2, multi=1,
                 activation=Relu, style='batchless',
                 init=init_he_normal):
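To make the schedule concrete, here is a rough check of its shape; the constants below are illustrative defaults, not values from the commit. cosmod sweeps from 1 down toward 0 over each cycle, and every restart scales the peak rate by restart_decay.

    import numpy as np

    def cosmod(x):
        return (1 + np.cos((x % 1) * np.pi)) / 2

    start_rate, restart_epochs, decay = 1e-2, 24, 0.5   # assumed values
    for epoch in (0, 12, 24, 36, 48):
        restart = epoch // restart_epochs
        x = (epoch % restart_epochs) / restart_epochs
        print(epoch, start_rate * decay**restart * cosmod(x))
    # 0  -> 0.01    (full rate at the start)
    # 12 -> 0.005   (cosine midpoint of the first cycle)
    # 24 -> 0.005   (first restart: peak scaled by decay)
    # 36 -> 0.0025  (midpoint of the second cycle)
    # 48 -> 0.0025  (second restart: peak scaled by decay**2)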
@@ -743,10 +789,11 @@ def run(program, args=[]):
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD
 
-        # learning parameters: SGD with restarts (kinda)
+        # learning parameters
+        learner = 'SGDR',
         learn = 1e-2,
         epochs = 24,
-        learn_halve_every = 16,
+        learn_halve_every = 16, # 12 might be ideal for SGDR?
         restarts = 2,
         learn_restart_advance = 16,
 
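With these defaults, the restart decay handed to SGDR further down is just arithmetic on the values above:

    decay = 0.5 ** (1 / (24 / 16))   # 0.5 ** (2/3), roughly 0.63
    # so each restart begins at about 63% of the previous cycle's peak rate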
@@ -826,11 +873,18 @@ def run(program, args=[]):
 
     #
 
-    learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
-                          halve_every=config.learn_halve_every,
-                          restarts=config.restarts, restart_advance=config.learn_restart_advance,
-                          callback=rscb)
-    log("final learning rate", "{:10.8f}".format(learner.final_rate))
+    if config.learner == 'SGDR':
+        decay = 0.5**(1/(config.epochs / config.learn_halve_every))
+        learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
+                       restart_decay=decay, restarts=config.restarts,
+                       callback=rscb)
+        # final learning rate isn't of interest here; it's gonna be close to 0.
+    else:
+        learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
+                              halve_every=config.learn_halve_every,
+                              restarts=config.restarts, restart_advance=config.learn_restart_advance,
+                              callback=rscb)
+        log("final learning rate", "{:10.8f}".format(learner.final_rate))
 
     #
 
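Worth noting about the wiring above (using the config defaults shown earlier): SGDR treats config.epochs as the length of a single cycle and expands the total itself, so the run grows with the number of restarts.

    total_epochs = 24 * (2 + 1)   # SGDR.__init__: restart_epochs * (restarts + 1) -> 72 epochs overall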