From b6597e8b6cfc811ff7d185b80e9a144c1c6ec3d0 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 10 Apr 2017 15:56:32 +0000 Subject: [PATCH] fix SGDR restart iteration and add WaveCLR --- optim_nn_core.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/optim_nn_core.py b/optim_nn_core.py index 0a52af9..1238d59 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -1026,10 +1026,12 @@ class SGDR(Learner): def split_num(self, epoch): shit = [0] + self.splits # hack for i in range(0, len(self.splits)): - if epoch < self.splits[i] + 1: + if epoch < self.splits[i]: sub_epoch = epoch - shit[i] next_restart = self.splits[i] - shit[i] return i, sub_epoch, next_restart + if epoch == self.splits[-1]: + return len(self.splits) - 1, epoch, self.splits[-1] raise Exception('this should never happen.') def rate_at(self, epoch): @@ -1081,3 +1083,7 @@ class TriangularCLR(Learner): class SineCLR(TriangularCLR): def _t(self, epoch): return np.sin(_pi * _inv2 * super()._t(epoch)) + +class WaveCLR(TriangularCLR): + def _t(self, epoch): + return _inv2 * (_1 - np.cos(_pi * super()._t(epoch)))