From 1b1184480a678bd2e78ba4ce7d35ebbcecf32541 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sun, 2 Jul 2017 02:52:07 +0000 Subject: [PATCH] allow optimizers to adjust their own learning rate --- onn_core.py | 9 +++++---- onn_mnist.py | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/onn_core.py b/onn_core.py index d6b552a..6ad5116 100644 --- a/onn_core.py +++ b/onn_core.py @@ -314,9 +314,8 @@ class RMSprop(Optimizer): class Adam(Optimizer): # paper: https://arxiv.org/abs/1412.6980 # Adam generalizes* RMSprop, and - # adds a decay term to the regular (non-squared) delta, and - # does some decay-gain voodoo. (i guess it's compensating - # for the filtered deltas starting from zero) + # adds a decay term to the regular (non-squared) delta, and performs + # debiasing to compensate for the filtered deltas starting from zero. # * Adam == RMSprop when # Adam.b1 == 0 @@ -1072,7 +1071,7 @@ class Learner: def __init__(self, optim, epochs=100, rate=None): assert isinstance(optim, Optimizer) self.optim = optim - self.start_rate = optim.alpha if rate is None else _f(rate) + self.start_rate = rate # None is okay; it'll use optim.alpha instead. self.epochs = int(epochs) self.reset() @@ -1101,6 +1100,8 @@ class Learner: self.optim.alpha = new_rate def rate_at(self, epoch): + if self.start_rate is None: + return self.optim.alpha return self.start_rate def next(self): diff --git a/onn_mnist.py b/onn_mnist.py index 4fdcfe0..e462c9e 100755 --- a/onn_mnist.py +++ b/onn_mnist.py @@ -136,10 +136,12 @@ if learner_class == SGDR: learner = learner_class(optim, epochs=epochs//starts, rate=lr, restarts=starts-1, restart_decay=restart_decay, expando=lambda i:0) -else: - assert learner_class in (TriangularCLR, SineCLR, WaveCLR) +elif learner_class in (TriangularCLR, SineCLR, WaveCLR): learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr, frequency=epochs//starts) +else: + lament('NOTE: no learning rate schedule selected.') + learner = Learner(optim, epochs=epochs) loss = CategoricalCrossentropy() mloss = Accuracy()