fix SGDR restart iteration and add WaveCLR

This commit is contained in:
Connor Olding 2017-04-10 15:56:32 +00:00
parent a448ff3e8a
commit b6597e8b6c

View file

@ -1026,10 +1026,12 @@ class SGDR(Learner):
def split_num(self, epoch): def split_num(self, epoch):
shit = [0] + self.splits # hack shit = [0] + self.splits # hack
for i in range(0, len(self.splits)): for i in range(0, len(self.splits)):
if epoch < self.splits[i] + 1: if epoch < self.splits[i]:
sub_epoch = epoch - shit[i] sub_epoch = epoch - shit[i]
next_restart = self.splits[i] - shit[i] next_restart = self.splits[i] - shit[i]
return i, sub_epoch, next_restart return i, sub_epoch, next_restart
if epoch == self.splits[-1]:
return len(self.splits) - 1, epoch, self.splits[-1]
raise Exception('this should never happen.') raise Exception('this should never happen.')
def rate_at(self, epoch): def rate_at(self, epoch):
@ -1081,3 +1083,7 @@ class TriangularCLR(Learner):
class SineCLR(TriangularCLR): class SineCLR(TriangularCLR):
def _t(self, epoch): def _t(self, epoch):
return np.sin(_pi * _inv2 * super()._t(epoch)) return np.sin(_pi * _inv2 * super()._t(epoch))
class WaveCLR(TriangularCLR):
def _t(self, epoch):
return _inv2 * (_1 - np.cos(_pi * super()._t(epoch)))