finally fix learning rate scheduling for real

okay, this is a disaster, but i think i've got it under control now.

the way batch-based learners now work is:
the truncated (integer) part of the epoch variable is the epoch we're
working towards, and the fractional part is how far we are into it.

epoch starts at 1, so subtract 1 when doing periodic operations.
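
to illustrate the convention, a minimal standalone sketch (not the repo's
code; the describe name is made up for demonstration):

    # the scheduler receives a single float: int(x) is the epoch being
    # worked towards, and x % 1.0 is how far we are into it.
    def describe(x):
        epoch = int(x)           # truncated part: current (1-based) epoch
        progress = x - epoch     # fractional part: progress within it
        return epoch, progress

    print(describe(1.0))   # (1, 0.0)  -- start of the first epoch
    print(describe(3.25))  # (3, 0.25) -- a quarter into epoch 3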
Connor Olding 2017-07-25 04:25:35 +00:00
parent 93547b1974
commit 2cf38d4ece

@@ -1155,7 +1155,7 @@ class Learner:
         # unlike epochs, we do not store batch number as a state.
         # i.e. calling next() will not respect progress.
         assert 0 <= progress <= 1
-        self.rate = self.rate_at(self._epoch - 1 + progress)
+        self.rate = self.rate_at(self._epoch + progress)
 
     @property
     def final_rate(self):
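
for illustration: with 1-based epochs, epoch 2's batches now sweep rate_at
over [2.0, 3.0) rather than [1.0, 2.0). a runnable sketch with a stand-in
schedule; only rate_at and the progress argument come from the hunk above:

    def rate_at(x):                      # stand-in schedule, demo only
        return 0.1 * 0.5 ** int(x - 1)

    _epoch = 2
    num_batches = 4
    for b in range(num_batches):
        progress = b / num_batches
        assert 0 <= progress <= 1
        print(_epoch + progress, rate_at(_epoch + progress))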
@@ -1188,6 +1188,7 @@ class SGDR(Learner):
         self.decay = _f(restart_decay)
         self.restarts = int(restarts)
         self.restart_callback = callback
+        # TODO: rename expando to something not insane
         self.expando = expando if expando is not None else lambda i: i
 
         if type(self.expando) == int:
@@ -1203,26 +1204,24 @@ class SGDR(Learner):
         super().__init__(optim, epochs, rate)
 
     def split_num(self, epoch):
-        shit = [0] + self.splits # hack
-        for i in range(0, len(self.splits)):
-            if epoch < self.splits[i]:
-                sub_epoch = epoch - shit[i]
-                next_restart = self.splits[i] - shit[i]
+        previous = [0] + self.splits
+        for i, split in enumerate(self.splits):
+            if epoch - 1 < split:
+                sub_epoch = epoch - previous[i]
+                next_restart = split - previous[i]
                 return i, sub_epoch, next_restart
-        if epoch == self.splits[-1]:
-            return len(self.splits) - 1, epoch, self.splits[-1]
         raise Exception('this should never happen.')
 
     def rate_at(self, epoch):
-        restart, sub_epoch, next_restart = self.split_num(epoch)
-        x = _f(sub_epoch) / _f(next_restart)
+        restart, sub_epoch, next_restart = self.split_num(max(1, epoch))
+        x = _f(sub_epoch - 1) / _f(next_restart)
         return self.start_rate * self.decay**_f(restart) * cosmod(x)
 
     def next(self):
         if not super().next():
             return False
         restart, sub_epoch, next_restart = self.split_num(self.epoch)
-        if restart > 0 and sub_epoch == 0:
+        if restart > 0 and sub_epoch == 1:
             if self.restart_callback is not None:
                 self.restart_callback(restart)
         return True
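
the fixed split_num logic, reproduced standalone for illustration. this
assumes splits holds cumulative 1-based epoch boundaries (e.g. restart
lengths of 10, 20, and 40 epochs would give splits == [10, 30, 70]):

    def split_num(splits, epoch):
        previous = [0] + splits
        for i, split in enumerate(splits):
            if epoch - 1 < split:
                return i, epoch - previous[i], split - previous[i]
        raise Exception('this should never happen.')

    splits = [10, 30, 70]
    print(split_num(splits, 1))   # (0, 1, 10)  -- first epoch of restart 0
    print(split_num(splits, 10))  # (0, 10, 10) -- last epoch of restart 0
    print(split_num(splits, 11))  # (1, 1, 20)  -- sub_epoch == 1: callback fires

note that rate_at then maps sub_epoch == 1 to x == 0 and sub_epoch == n to
x == (n-1)/n, so x never reaches 1 before the next restart begins.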
@@ -1245,7 +1244,7 @@ class TriangularCLR(Learner):
     def _t(self, epoch):
         # NOTE: this could probably be simplified
         offset = self.frequency / 2
-        return np.abs(((epoch + offset) % self.frequency) - offset) / offset
+        return np.abs(((epoch - 1 + offset) % self.frequency) - offset) / offset
 
     def rate_at(self, epoch):
         # NOTE: start_rate is treated as upper_rate
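
the fixed triangle wave, reproduced standalone (plain abs instead of np.abs
so the sketch has no dependencies):

    def _t(epoch, frequency):
        offset = frequency / 2
        return abs(((epoch - 1 + offset) % frequency) - offset) / offset

    print([_t(e, 4) for e in range(1, 10)])
    # [0.0, 0.5, 1.0, 0.5, 0.0, 0.5, 1.0, 0.5, 0.0]
    # epoch 1 now sits at the start of the triangle (t == 0) instead of
    # partway up it.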
@@ -1254,7 +1253,8 @@ class TriangularCLR(Learner):
     def next(self):
         if not super().next():
             return False
-        if self.epoch > 1 and self.epoch % self.frequency == 0:
+        e = self.epoch - 1
+        if e > 0 and e % self.frequency == 0:
             if self.callback is not None:
                 self.callback(self.epoch // self.frequency)
         return True
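
with the new arithmetic, the callback fires on the first epoch after each
completed cycle. a sketch of just the condition (the frequency value is
assumed for demonstration):

    frequency = 4
    for epoch in range(1, 14):
        e = epoch - 1
        if e > 0 and e % frequency == 0:
            print('epoch', epoch, '-> callback cycle', epoch // frequency)
    # epoch 5 -> callback cycle 1
    # epoch 9 -> callback cycle 2
    # epoch 13 -> callback cycle 3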