commit eb0011cc35
parent 3fd6ef9688

optim_nn.py (62 lines changed)
@@ -555,6 +555,9 @@ class Ritual: # i'm just making up names at this point
             batch_inputs = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
+            if self.learner.per_batch:
+                self.learner.batch(b / batch_count)
+
             predicted = model.forward(batch_inputs)
             residual = predicted - batch_outputs
 
@@ -574,6 +577,8 @@ class Ritual: # i'm just making up names at this point
         return avg_loss
 
 class Learner:
+    per_batch = False
+
     def __init__(self, optim, epochs=100, rate=None):
         assert isinstance(optim, Optimizer)
         self.optim = optim
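
Taken together, these first two hunks set up a small contract: Ritual's minibatch loop reports fractional progress to the learner, but only when the learner opts in by setting per_batch, and the base Learner defaults it to False so existing per-epoch schedules are unaffected. SGDR, added further down, is the first learner to opt in.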
@@ -608,7 +613,7 @@ class Learner:
         return self.start_rate
 
     def next(self):
-        # returns whether or not to continue training. updates.
+        # prepares the next epoch. returns whether or not to continue training.
         if self.epoch + 1 >= self.epochs:
             return False
         if self.started:
@@ -618,6 +623,13 @@ class Learner:
         self.epoch = self.epoch # poke property setter just in case
         return True
 
+    def batch(self, progress): # TODO: rename
+        # interpolates rates between epochs.
+        # unlike epochs, we do not store batch number as a state.
+        # i.e. calling next() will not respect progress.
+        assert 0 <= progress <= 1
+        self.rate = self.rate_at(self._epoch + progress)
+
     @property
     def final_rate(self):
         return self.rate_at(self.epochs - 1)
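
To make the fractional-epoch semantics concrete, here is a small sketch (not from the commit; rate_at is a stand-in schedule that halves the rate each epoch):

    # sketch only: batch() evaluates the schedule at a fractional epoch.
    start_rate = 0.1
    rate_at = lambda epoch: start_rate * 0.5**epoch  # hypothetical schedule

    _epoch = 3
    for progress in (0.0, 0.25, 0.5, 1.0):
        print(progress, rate_at(_epoch + progress))
    # sweeps smoothly from rate_at(3) toward rate_at(4); progress is never
    # stored, so next() still advances whole epochs regardless of where
    # batch() left off.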
@@ -629,7 +641,7 @@ class AnnealingLearner(Learner):
         super().__init__(optim, epochs, rate)
 
     def rate_at(self, epoch):
-        return super().rate_at(epoch) * self.anneal**epoch
+        return self.start_rate * self.anneal**epoch
 
 class DumbLearner(AnnealingLearner):
     # this is my own awful contraption. it's not really "SGD with restarts".
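
The rate_at change here reads as a decoupling rather than a behavior change: if the base Learner.rate_at simply returns start_rate (as the earlier hunk suggests), the two expressions are equivalent, but computing from start_rate directly keeps the annealing curve independent of whatever the base class does.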
@@ -656,6 +668,40 @@ class DumbLearner(AnnealingLearner):
             self.restart_callback(restart)
         return True
 
+def cosmod(x):
+    # plot: https://www.desmos.com/calculator/hlgqmyswy2
+    return (1 + np.cos((x % 1) * np.pi)) / 2
+
+class SGDR(Learner):
+    # Stochastic Gradient Descent with Restarts
+    # paper: https://arxiv.org/abs/1608.03983
+    # NOTE: this is not a complete implementation.
+
+    per_batch = True
+
+    def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None):
+        self.restart_epochs = int(epochs)
+        self.decay = float(restart_decay)
+        self.restarts = int(restarts)
+        self.restart_callback = callback
+        epochs = self.restart_epochs * (self.restarts + 1)
+        super().__init__(optim, epochs, rate)
+
+    def rate_at(self, epoch):
+        sub_epoch = epoch % self.restart_epochs
+        x = sub_epoch / self.restart_epochs
+        restart = epoch // self.restart_epochs
+        return self.start_rate * self.decay**restart * cosmod(x)
+
+    def next(self):
+        if not super().next():
+            return False
+        sub_epoch = self.epoch % self.restart_epochs
+        restart = self.epoch // self.restart_epochs
+        if restart > 0 and sub_epoch == 0:
+            if self.restart_callback is not None:
+                self.restart_callback(restart)
+        return True
+
 def multiresnet(x, width, depth, block=2, multi=1,
                 activation=Relu, style='batchless',
                 init=init_he_normal):
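
As a sanity check on the new schedule, the following standalone sketch reproduces rate_at from the hunk above with small, hypothetical parameters (start_rate=0.1, restart_epochs=2, decay=0.5):

    import numpy as np

    def cosmod(x):
        return (1 + np.cos((x % 1) * np.pi)) / 2

    start_rate, restart_epochs, decay = 0.1, 2, 0.5  # hypothetical values

    def rate_at(epoch):
        x = (epoch % restart_epochs) / restart_epochs
        restart = epoch // restart_epochs
        return start_rate * decay**restart * cosmod(x)

    for e in range(5):
        print(e, rate_at(e))
    # 0 -> 0.1    cosmod(0) = 1: start of the first cycle
    # 1 -> 0.05   cosmod(0.5) = 0.5: halfway down the cosine
    # 2 -> 0.05   restart: the rate jumps back up, scaled by decay**1
    # 3 -> 0.025
    # 4 -> 0.025  second restart, scaled by decay**2

Because per_batch is True, epoch takes fractional values in practice, so the cosine is traced smoothly within each cycle rather than sampled once per epoch.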
@@ -743,10 +789,11 @@ def run(program, args=[]):
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD
 
-        # learning parameters: SGD with restarts (kinda)
+        # learning parameters
+        learner = 'SGDR',
         learn = 1e-2,
         epochs = 24,
-        learn_halve_every = 16,
+        learn_halve_every = 16, # 12 might be ideal for SGDR?
         restarts = 2,
         learn_restart_advance = 16,
 
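
One subtlety with these defaults: for SGDR, epochs=24 is the length of a single restart cycle, and the constructor multiplies it out, so restarts=2 yields 24 * (2 + 1) = 72 total training epochs.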
@@ -826,6 +873,13 @@ def run(program, args=[]):
 
     #
 
-    learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
-                          halve_every=config.learn_halve_every,
-                          restarts=config.restarts, restart_advance=config.learn_restart_advance,
+    if config.learner == 'SGDR':
+        decay = 0.5**(1/(config.epochs / config.learn_halve_every))
+        learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
+                       restart_decay=decay, restarts=config.restarts,
+                       callback=rscb)
+        # final learning rate isn't of interest here; it's gonna be close to 0.
+    else:
+        learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
+                              halve_every=config.learn_halve_every,
+                              restarts=config.restarts, restart_advance=config.learn_restart_advance,
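
For the defaults above (epochs=24, learn_halve_every=16), the decay expression evaluates as follows:

    decay = 0.5**(1/(24 / 16))  # = 0.5**(16/24) = 0.5**(2/3)
    print(round(decay, 4))      # ~0.63: each cycle restarts at ~63% of the previous peak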