diff --git a/onn/optimizer.py b/onn/optimizer.py index 2aa05f8..468741e 100644 --- a/onn/optimizer.py +++ b/onn/optimizer.py @@ -12,12 +12,13 @@ def filter_gradients(accum, grads, param): # param == 0 simply copies grads into accum. if param == 0: accum[:] = grads - if param < 0: - if param != -1: - accum *= -param - accum += grads elif param == 1: pass + elif param == -1: + accum += grads + elif param < 0: + accum *= -param + accum += grads else: accum += (1 - param) * (grads - accum) return accum