diff --git a/optim_nn_core.py b/optim_nn_core.py index d54056e..38f1eda 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -1090,13 +1090,16 @@ class SGDR(Learner): def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None, - expando=None): + expando=0): self.restart_epochs = int(epochs) self.decay = _f(restart_decay) self.restarts = int(restarts) self.restart_callback = callback # TODO: rename expando to something not insane self.expando = expando if expando is not None else lambda i: i + if type(self.expando) == int: + inc = self.expando + self.expando = self.expando = lambda i: inc self.splits = [] epochs = 0