fix SGDR restart iteration and add WaveCLR
This commit is contained in:
parent
a448ff3e8a
commit
b6597e8b6c
1 changed files with 7 additions and 1 deletions
|
@ -1026,10 +1026,12 @@ class SGDR(Learner):
|
|||
def split_num(self, epoch):
|
||||
shit = [0] + self.splits # hack
|
||||
for i in range(0, len(self.splits)):
|
||||
if epoch < self.splits[i] + 1:
|
||||
if epoch < self.splits[i]:
|
||||
sub_epoch = epoch - shit[i]
|
||||
next_restart = self.splits[i] - shit[i]
|
||||
return i, sub_epoch, next_restart
|
||||
if epoch == self.splits[-1]:
|
||||
return len(self.splits) - 1, epoch, self.splits[-1]
|
||||
raise Exception('this should never happen.')
|
||||
|
||||
def rate_at(self, epoch):
|
||||
|
@ -1081,3 +1083,7 @@ class TriangularCLR(Learner):
|
|||
class SineCLR(TriangularCLR):
|
||||
def _t(self, epoch):
|
||||
return np.sin(_pi * _inv2 * super()._t(epoch))
|
||||
|
||||
class WaveCLR(TriangularCLR):
|
||||
def _t(self, epoch):
|
||||
return _inv2 * (_1 - np.cos(_pi * super()._t(epoch)))
|
||||
|
|
Loading…
Reference in a new issue