diff --git a/onn_core.py b/onn_core.py index e71118b..e8a948a 100644 --- a/onn_core.py +++ b/onn_core.py @@ -1262,9 +1262,10 @@ class SGDR(Learner): raise Exception('this should never happen.') def rate_at(self, epoch): + base_rate = self.start_rate if self.start_rate is not None else self.optim.lr restart, sub_epoch, next_restart = self.split_num(max(1, epoch)) x = _f(sub_epoch - 1) / _f(next_restart) - return self.start_rate * self.decay**_f(restart) * cosmod(x) + return base_rate * self.decay**_f(restart) * cosmod(x) def next(self): if not super().next():