fix SGDR restart iteration and add WaveCLR
This commit is contained in:
parent
a448ff3e8a
commit
b6597e8b6c
1 changed files with 7 additions and 1 deletions
|
@ -1026,10 +1026,12 @@ class SGDR(Learner):
|
||||||
def split_num(self, epoch):
|
def split_num(self, epoch):
|
||||||
shit = [0] + self.splits # hack
|
shit = [0] + self.splits # hack
|
||||||
for i in range(0, len(self.splits)):
|
for i in range(0, len(self.splits)):
|
||||||
if epoch < self.splits[i] + 1:
|
if epoch < self.splits[i]:
|
||||||
sub_epoch = epoch - shit[i]
|
sub_epoch = epoch - shit[i]
|
||||||
next_restart = self.splits[i] - shit[i]
|
next_restart = self.splits[i] - shit[i]
|
||||||
return i, sub_epoch, next_restart
|
return i, sub_epoch, next_restart
|
||||||
|
if epoch == self.splits[-1]:
|
||||||
|
return len(self.splits) - 1, epoch, self.splits[-1]
|
||||||
raise Exception('this should never happen.')
|
raise Exception('this should never happen.')
|
||||||
|
|
||||||
def rate_at(self, epoch):
|
def rate_at(self, epoch):
|
||||||
|
@ -1081,3 +1083,7 @@ class TriangularCLR(Learner):
|
||||||
class SineCLR(TriangularCLR):
|
class SineCLR(TriangularCLR):
|
||||||
def _t(self, epoch):
|
def _t(self, epoch):
|
||||||
return np.sin(_pi * _inv2 * super()._t(epoch))
|
return np.sin(_pi * _inv2 * super()._t(epoch))
|
||||||
|
|
||||||
|
class WaveCLR(TriangularCLR):
|
||||||
|
def _t(self, epoch):
|
||||||
|
return _inv2 * (_1 - np.cos(_pi * super()._t(epoch)))
|
||||||
|
|
Loading…
Reference in a new issue