diff --git a/optim_nn_core.py b/optim_nn_core.py index 0a52af9..1238d59 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -1026,10 +1026,12 @@ class SGDR(Learner): def split_num(self, epoch): shit = [0] + self.splits # hack for i in range(0, len(self.splits)): - if epoch < self.splits[i] + 1: + if epoch < self.splits[i]: sub_epoch = epoch - shit[i] next_restart = self.splits[i] - shit[i] return i, sub_epoch, next_restart + if epoch == self.splits[-1]: + return len(self.splits) - 1, epoch, self.splits[-1] raise Exception('this should never happen.') def rate_at(self, epoch): @@ -1081,3 +1083,7 @@ class TriangularCLR(Learner): class SineCLR(TriangularCLR): def _t(self, epoch): return np.sin(_pi * _inv2 * super()._t(epoch)) + +class WaveCLR(TriangularCLR): + def _t(self, epoch): + return _inv2 * (_1 - np.cos(_pi * super()._t(epoch)))