generalize Adam-like optimizers
This commit is contained in:
parent
c6ebd02ea9
commit
f60535aa01
1 changed files with 158 additions and 2 deletions
160
onn/optimizer.py
160
onn/optimizer.py
|
@ -4,8 +4,23 @@ from .float import _f, _0, _1
|
||||||
from .optimizer_base import *
|
from .optimizer_base import *
|
||||||
from .utility import *
|
from .utility import *
|
||||||
|
|
||||||
# some of the the following optimizers are blatantly lifted from tiny-dnn:
|
|
||||||
# https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h
|
def filter_gradients(accum, grads, param):
|
||||||
|
# NOTE: this modifies accum in-place.
|
||||||
|
# param > 0 acts as a simple one-pole low-pass filter, unity at DC.
|
||||||
|
# param < 0 acts as an accumulator with a decay of -param, nonunity at DC.
|
||||||
|
# param == 0 simply copies grads into accum.
|
||||||
|
if param == 0:
|
||||||
|
accum[:] = grads
|
||||||
|
if param < 0:
|
||||||
|
if param != -1:
|
||||||
|
accum *= -param
|
||||||
|
accum += grads
|
||||||
|
elif param == 1:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
accum += (1 - param) * (grads - accum)
|
||||||
|
return accum
|
||||||
|
|
||||||
|
|
||||||
class Momentum(Optimizer):
|
class Momentum(Optimizer):
|
||||||
|
@ -598,3 +613,144 @@ class AMSgrad(Optimizer):
|
||||||
self.b2_t *= self.b2
|
self.b2_t *= self.b2
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class Adamlike(Optimizer):
|
||||||
|
# this generalizes a lot of algorithms that are
|
||||||
|
# either subsets or supersets of the Adam optimizer.
|
||||||
|
# refer to the subclasses for details.
|
||||||
|
|
||||||
|
# the arguments to init default to Adam's.
|
||||||
|
def __init__(self, lr=0.001, b1=0.9, b2=0.999,
|
||||||
|
power=1/2, debias=True, runmax=False, eps=1e-8):
|
||||||
|
self.b1 = _f(b1) # decay term
|
||||||
|
self.b2 = _f(b2) # decay term
|
||||||
|
self.b1_t_default = _f(b1) # decay term power t
|
||||||
|
self.b2_t_default = _f(b2) # decay term power t
|
||||||
|
self.power = _f(power)
|
||||||
|
self.debias = bool(debias)
|
||||||
|
self.runmax = bool(runmax)
|
||||||
|
self.eps = _f(eps)
|
||||||
|
|
||||||
|
super().__init__(lr)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.mt = None
|
||||||
|
self.vt = None
|
||||||
|
self.vtmax = None
|
||||||
|
self.b1_t = self.b1_t_default
|
||||||
|
self.b2_t = self.b2_t_default
|
||||||
|
|
||||||
|
def compute(self, dW, W):
|
||||||
|
if self.mt is None:
|
||||||
|
self.mt = np.zeros_like(dW)
|
||||||
|
if self.vt is None:
|
||||||
|
self.vt = np.zeros_like(dW)
|
||||||
|
if self.vtmax is None and self.runmax:
|
||||||
|
self.vtmax = np.zeros_like(dW)
|
||||||
|
|
||||||
|
# keep local references of mt and vt to simplify
|
||||||
|
# implementing all the variations of Adam later.
|
||||||
|
mt = filter_gradients(self.mt, dW, self.b1)
|
||||||
|
vt = filter_gradients(self.vt, np.square(dW), self.b2)
|
||||||
|
|
||||||
|
if self.runmax:
|
||||||
|
self.vtmax[:] = np.maximum(vt, self.vtmax)
|
||||||
|
vt = self.vtmax
|
||||||
|
|
||||||
|
if self.debias:
|
||||||
|
if self.b1_t != 1:
|
||||||
|
mt = mt / (1 - self.b1_t)
|
||||||
|
if self.b2_t != 1:
|
||||||
|
vt = vt / (1 - self.b2_t)
|
||||||
|
|
||||||
|
if self.power == 0:
|
||||||
|
delta = mt
|
||||||
|
elif self.power == 1:
|
||||||
|
delta = mt / (vt + self.eps)
|
||||||
|
elif self.power == 1/2: # TODO: is this actually faster?
|
||||||
|
delta = mt / (np.sqrt(vt) + self.eps)
|
||||||
|
elif self.power == 1/3: # TODO: is this actually faster?
|
||||||
|
delta = mt / (np.cbrt(vt) + self.eps)
|
||||||
|
else:
|
||||||
|
delta = mt / (vt**self.power + self.eps)
|
||||||
|
|
||||||
|
if self.debias:
|
||||||
|
# decay gain.
|
||||||
|
self.b1_t *= self.b1
|
||||||
|
self.b2_t *= self.b2
|
||||||
|
|
||||||
|
return -self.lr * delta
|
||||||
|
|
||||||
|
|
||||||
|
class Adagrad(Adamlike):
|
||||||
|
# paper: https://web.stanford.edu/~jduchi/projects/DuchiHaSi11.pdf
|
||||||
|
|
||||||
|
def __init__(self, lr=0.01, eps=1e-8):
|
||||||
|
super().__init__(lr=lr, b1=0.0, b2=-1.0,
|
||||||
|
power=1/2, debias=False, runmax=False, eps=eps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g(self):
|
||||||
|
return self.vt
|
||||||
|
|
||||||
|
@g.setter
|
||||||
|
def g(self, value):
|
||||||
|
self.vt = value
|
||||||
|
|
||||||
|
|
||||||
|
class RMSprop(Adamlike):
|
||||||
|
# slides: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
|
||||||
|
|
||||||
|
def __init__(self, lr=0.001, mu=0.99, eps=1e-8):
|
||||||
|
super().__init__(lr=lr, b1=0.0, b2=mu,
|
||||||
|
power=1/2, debias=False, runmax=False, eps=eps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mu(self):
|
||||||
|
return self.b2
|
||||||
|
|
||||||
|
@mu.setter
|
||||||
|
def mu(self, value):
|
||||||
|
self.b2 = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g(self):
|
||||||
|
return self.vt
|
||||||
|
|
||||||
|
@g.setter
|
||||||
|
def g(self, value):
|
||||||
|
self.vt = value
|
||||||
|
|
||||||
|
|
||||||
|
class Adam(Adamlike):
|
||||||
|
# paper: https://arxiv.org/abs/1412.6980
|
||||||
|
# Adam generalizes RMSprop, and
|
||||||
|
# adds a decay term to the regular (non-squared) delta, and performs
|
||||||
|
# debiasing to compensate for the filtered deltas starting from zero.
|
||||||
|
|
||||||
|
def __init__(self, lr=0.001, b1=0.9, b2=0.999,
|
||||||
|
debias=True, eps=1e-8):
|
||||||
|
super().__init__(lr=lr, b1=b1, b2=b2,
|
||||||
|
power=1/2, debias=debias, runmax=False, eps=eps)
|
||||||
|
|
||||||
|
|
||||||
|
class AMSgrad(Adamlike):
|
||||||
|
# paper: https://openreview.net/forum?id=ryQu7f-RZ
|
||||||
|
# based on Adam. this simply adds a running element-wise maximum to vt.
|
||||||
|
|
||||||
|
def __init__(self, lr=0.001, b1=0.9, b2=0.999,
|
||||||
|
debias=True, eps=1e-8):
|
||||||
|
super().__init__(lr=lr, b1=b1, b2=b2,
|
||||||
|
power=1/2, debias=debias, runmax=True, eps=eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Padam(Adamlike):
|
||||||
|
# paper: https://arxiv.org/abs/1806.06763
|
||||||
|
# paper: https://arxiv.org/abs/1808.05671
|
||||||
|
# based on AMSgrad. this configures the power of vt to be closer to zero.
|
||||||
|
|
||||||
|
def __init__(self, lr=0.1, b1=0.9, b2=0.999,
|
||||||
|
power=1/8, debias=True, eps=1e-8):
|
||||||
|
super().__init__(lr=lr, b1=b1, b2=b2,
|
||||||
|
power=power, debias=debias, runmax=True, eps=eps)
|
||||||
|
|
Loading…
Add table
Reference in a new issue