refactor gradient filtering

This commit is contained in:
Connor Olding 2019-02-03 15:10:43 +01:00
parent 0d28882ef0
commit 54ea41711b

View file

@ -12,12 +12,13 @@ def filter_gradients(accum, grads, param):
# param == 0 simply copies grads into accum. # param == 0 simply copies grads into accum.
if param == 0: if param == 0:
accum[:] = grads accum[:] = grads
if param < 0:
if param != -1:
accum *= -param
accum += grads
elif param == 1: elif param == 1:
pass pass
elif param == -1:
accum += grads
elif param < 0:
accum *= -param
accum += grads
else: else:
accum += (1 - param) * (grads - accum) accum += (1 - param) * (grads - accum)
return accum return accum