diff --git a/onn/optimizer.py b/onn/optimizer.py index 468741e..dfc483a 100644 --- a/onn/optimizer.py +++ b/onn/optimizer.py @@ -34,6 +34,8 @@ class Momentum(Optimizer): def reset(self): self.Vprev = None + super().reset() + def compute(self, dW, W): if self.Vprev is None: self.Vprev = np.copy(dW) @@ -59,6 +61,8 @@ class Adadelta(Optimizer): self.g = None self.x = None + super().reset() + def compute(self, dW, W): if self.g is None: self.g = np.zeros_like(dW) @@ -88,6 +92,8 @@ class RMSpropCentered(Optimizer): self.vt = None self.delta = None + super().reset() + def compute(self, dW, W): if self.g is None: self.g = np.zeros_like(dW) @@ -136,6 +142,8 @@ class Nadam(Optimizer): self.t = 0 self.sched = 1 + super().reset() + def compute(self, dW, W): self.t += 1 @@ -184,6 +192,8 @@ class FTML(Optimizer): self.b1_t = _1 self.b2_t = _1 + super().reset() + def compute(self, dW, W): if self.dt1 is None: self.dt1 = np.zeros_like(dW) @@ -231,6 +241,8 @@ class MomentumClip(Optimizer): def reset(self): self.accum = None + super().reset() + def compute(self, dW, W): if self.accum is None: self.accum = np.zeros_like(dW) @@ -261,6 +273,8 @@ class AddSign(Optimizer): def reset(self): self.accum = None + super().reset() + def compute(self, dW, W): if self.accum is None: self.accum = np.zeros_like(dW) @@ -286,6 +300,8 @@ class PowerSign(Optimizer): def reset(self): self.accum = None + super().reset() + def compute(self, dW, W): if self.accum is None: self.accum = np.zeros_like(dW) @@ -326,6 +342,8 @@ class Neumann(Optimizer): self.vt = None # weight accumulator. self.t = 0 + super().reset() + def compute(self, dW, W): raise Exception("compute() is not available for this Optimizer.") @@ -389,6 +407,8 @@ class Adamlike(Optimizer): self.b1_t = self.b1_t_default self.b2_t = self.b2_t_default + super().reset() + def compute(self, dW, W): if self.mt is None: self.mt = np.zeros_like(dW) diff --git a/onn/optimizer_base.py b/onn/optimizer_base.py index 6977539..6131ad0 100644 --- a/onn/optimizer_base.py +++ b/onn/optimizer_base.py @@ -6,10 +6,11 @@ from .float import _f class Optimizer: def __init__(self, lr=0.1): self.lr = _f(lr) # learning rate + self.base_rate = self.lr self.reset() def reset(self): - pass + self.lr = self.base_rate def compute(self, dW, W): return -self.lr * dW