finally fix learning rate scheduling for real
okay, this is a disaster, but i think i've got it under control now. the way batch-based learners now work is: the epoch we're working towards is the truncated part of the epoch variable, and how far we are into the epoch is the fractional part. epoch starts at 1, so subtract 1 from it when doing periodic operations.
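to make that concrete, here's a rough sketch of the convention, not the library's actual update loop (describe(), the frequency value, and the loop below are just illustrative):

# the truncated part of the epoch variable is the epoch being worked towards,
# the fractional part is how far into it we are, and epochs are 1-based,
# so anything periodic (restarts, callbacks) subtracts 1 before taking a modulus.
def describe(epoch):
    current, progress = int(epoch), epoch - int(epoch)
    return current, progress

print(describe(1.0))   # (1, 0.0)  -> just starting epoch 1
print(describe(2.25))  # (2, 0.25) -> a quarter of the way into epoch 2

frequency = 4  # hypothetical cycle length in epochs, like TriangularCLR's frequency
for epoch in range(1, 10):
    if (epoch - 1) > 0 and (epoch - 1) % frequency == 0:
        print('periodic event at epoch', epoch)  # fires at epochs 5 and 9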
parent 93547b1974
commit 2cf38d4ece

onn_core.py: 26 changed lines
@@ -1155,7 +1155,7 @@ class Learner:
         # unlike epochs, we do not store batch number as a state.
         # i.e. calling next() will not respect progress.
         assert 0 <= progress <= 1
-        self.rate = self.rate_at(self._epoch - 1 + progress)
+        self.rate = self.rate_at(self._epoch + progress)
 
     @property
     def final_rate(self):
@@ -1188,6 +1188,7 @@ class SGDR(Learner):
         self.decay = _f(restart_decay)
         self.restarts = int(restarts)
         self.restart_callback = callback
+
         # TODO: rename expando to something not insane
         self.expando = expando if expando is not None else lambda i: i
         if type(self.expando) == int:
@@ -1203,26 +1204,24 @@ class SGDR(Learner):
         super().__init__(optim, epochs, rate)
 
     def split_num(self, epoch):
-        shit = [0] + self.splits # hack
-        for i in range(0, len(self.splits)):
-            if epoch < self.splits[i]:
-                sub_epoch = epoch - shit[i]
-                next_restart = self.splits[i] - shit[i]
+        previous = [0] + self.splits
+        for i, split in enumerate(self.splits):
+            if epoch - 1 < split:
+                sub_epoch = epoch - previous[i]
+                next_restart = split - previous[i]
                 return i, sub_epoch, next_restart
-        if epoch == self.splits[-1]:
-            return len(self.splits) - 1, epoch, self.splits[-1]
         raise Exception('this should never happen.')
 
     def rate_at(self, epoch):
-        restart, sub_epoch, next_restart = self.split_num(epoch)
-        x = _f(sub_epoch) / _f(next_restart)
+        restart, sub_epoch, next_restart = self.split_num(max(1, epoch))
+        x = _f(sub_epoch - 1) / _f(next_restart)
         return self.start_rate * self.decay**_f(restart) * cosmod(x)
 
     def next(self):
         if not super().next():
             return False
         restart, sub_epoch, next_restart = self.split_num(self.epoch)
-        if restart > 0 and sub_epoch == 0:
+        if restart > 0 and sub_epoch == 1:
             if self.restart_callback is not None:
                 self.restart_callback(restart)
         return True
@@ -1245,7 +1244,7 @@ class TriangularCLR(Learner):
     def _t(self, epoch):
         # NOTE: this could probably be simplified
         offset = self.frequency / 2
-        return np.abs(((epoch + offset) % self.frequency) - offset) / offset
+        return np.abs(((epoch - 1 + offset) % self.frequency) - offset) / offset
 
     def rate_at(self, epoch):
         # NOTE: start_rate is treated as upper_rate
@@ -1254,7 +1253,8 @@ class TriangularCLR(Learner):
     def next(self):
         if not super().next():
             return False
-        if self.epoch > 1 and self.epoch % self.frequency == 0:
+        e = self.epoch - 1
+        if e > 0 and e % self.frequency == 0:
             if self.callback is not None:
                 self.callback(self.epoch // self.frequency)
         return True
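as a sanity check on the new indexing, here's a minimal standalone sketch of the corrected split_num logic from the hunk above (splits is passed in as a plain list instead of self.splits, and the example boundaries [3, 7] are made up; splits holds cumulative restart boundaries in epochs, as the old code also assumed):

# minimal standalone sketch of the corrected split_num; epochs are 1-based.
def split_num(splits, epoch):
    previous = [0] + splits
    for i, split in enumerate(splits):
        if epoch - 1 < split:
            sub_epoch = epoch - previous[i]     # 1-based position within the current cycle
            next_restart = split - previous[i]  # length of the current cycle in epochs
            return i, sub_epoch, next_restart
    raise Exception('this should never happen.')

# with splits = [3, 7] (restart after epoch 3, finish after epoch 7):
#   epoch 1 -> (0, 1, 3)    epoch 3 -> (0, 3, 3)
#   epoch 4 -> (1, 1, 4)    epoch 7 -> (1, 4, 4)
# rate_at then uses x = (sub_epoch - 1) / next_restart, so x starts at 0
# at the first epoch of each cycle.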