diff --git a/onn.py b/onn.py index 5a86f97..c5639bf 100755 --- a/onn.py +++ b/onn.py @@ -766,6 +766,18 @@ class NoisyRitual(Ritual): # Learners {{{1 +class PolyLearner(Learner): + per_batch = True + + def __init__(self, optim, epochs=400, coeffs=(1,)): + self.coeffs = tuple(coeffs) + super().__init__(optim, epochs, np.polyval(self.coeffs, 0)) + + def rate_at(self, epoch): + progress = (epoch - 1) / (self.epochs) + ret = np.polyval(self.coeffs, progress) + return np.abs(ret) + class DumbLearner(AnnealingLearner): # this is my own awful contraption. it's not really "SGD with restarts". def __init__(self, optim, epochs=100, rate=None, halve_every=10,