diff --git a/onn/optimizer.py b/onn/optimizer.py
index b6cda0f..a349c11 100644
--- a/onn/optimizer.py
+++ b/onn/optimizer.py
@@ -26,26 +26,25 @@ def filter_gradients(accum, grads, param):
 class Momentum(Optimizer):
     def __init__(self, lr=0.01, mu=0.9, nesterov=False):
-        self.mu = _f(mu) # momentum
+        self.mu = _f(mu)
         self.nesterov = bool(nesterov)
 
         super().__init__(lr)
 
     def reset(self):
-        self.Vprev = None
+        self.accum = None
 
         super().reset()
 
     def compute(self, dW, W):
-        if self.Vprev is None:
-            self.Vprev = np.copy(dW)
+        if self.accum is None:
+            self.accum = np.zeros_like(dW)
 
-        V = self.mu * self.Vprev - self.lr * dW
-        self.Vprev[:] = V
+        self.accum[:] = self.accum * self.mu + dW
 
         if self.nesterov:
-            return self.mu * V - self.lr * dW
-
-        return V
+            return -self.lr * (self.accum * self.mu + dW)
+else:
+            return -self.lr * self.accum
 
 
 class Adadelta(Optimizer):
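
Note on the change: the accumulator now stores a decaying sum of raw gradients (`accum = mu * accum + dW`) initialized to zeros, and the learning rate is applied only when the step is returned, rather than being baked into the stored velocity as in the old `Vprev` formulation. The Nesterov branch returns the lookahead step `-lr * (mu * accum + dW)`, the standard "NAG as modified momentum" form. Below is a minimal standalone sketch of the new update rule; the free function `momentum_step` and the toy quadratic objective are illustrative only and are not part of onn/optimizer.py:

```python
import numpy as np

def momentum_step(accum, dW, lr=0.01, mu=0.9, nesterov=False):
    # Accumulate a decaying sum of gradients in place: accum <- mu * accum + dW.
    accum[:] = mu * accum + dW
    if nesterov:
        # Nesterov lookahead: step along the accumulator as if it decayed once more.
        return -lr * (mu * accum + dW)
    return -lr * accum

# Toy usage: minimize f(w) = 0.5 * ||w||^2, whose gradient is simply w.
w = np.array([1.0, -2.0])
accum = np.zeros_like(w)
for _ in range(100):
    w += momentum_step(accum, w, lr=0.1, mu=0.9)
print(w)  # approaches [0, 0]
```

One practical upshot of keeping `lr` out of the accumulator: the learning rate can be changed between steps (e.g. by a schedule) and only future steps are rescaled, since the stored history is lr-free.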