refactor gradient filtering
This commit is contained in:
parent
0d28882ef0
commit
54ea41711b
1 changed files with 5 additions and 4 deletions
|
@ -12,12 +12,13 @@ def filter_gradients(accum, grads, param):
|
||||||
# param == 0 simply copies grads into accum.
|
# param == 0 simply copies grads into accum.
|
||||||
if param == 0:
|
if param == 0:
|
||||||
accum[:] = grads
|
accum[:] = grads
|
||||||
if param < 0:
|
|
||||||
if param != -1:
|
|
||||||
accum *= -param
|
|
||||||
accum += grads
|
|
||||||
elif param == 1:
|
elif param == 1:
|
||||||
pass
|
pass
|
||||||
|
elif param == -1:
|
||||||
|
accum += grads
|
||||||
|
elif param < 0:
|
||||||
|
accum *= -param
|
||||||
|
accum += grads
|
||||||
else:
|
else:
|
||||||
accum += (1 - param) * (grads - accum)
|
accum += (1 - param) * (grads - accum)
|
||||||
return accum
|
return accum
|
||||||
|
|
Loading…
Reference in a new issue