diff --git a/onn/optimizer.py b/onn/optimizer.py index 3a42dac..c31843a 100644 --- a/onn/optimizer.py +++ b/onn/optimizer.py @@ -387,8 +387,8 @@ class Adamlike(Optimizer): # refer to the subclasses for details. # these defaults match Adam's. - def __init__(self, lr=0.001, b1=0.9, b2=0.999, - power=1/2, debias=True, runmax=False, eps=1e-8): + def __init__(self, lr=0.001, b1=0.9, b2=0.999, power=1/2, + debias=True, runmax=False, yogi=False, eps=1e-8): self.b1 = _f(b1) # decay term self.b2 = _f(b2) # decay term self.b1_t_default = _f(b1) # decay term power t @@ -396,6 +396,7 @@ class Adamlike(Optimizer): self.power = _f(power) self.debias = bool(debias) self.runmax = bool(runmax) + self.yogi = bool(yogi) self.eps = _f(eps) super().__init__(lr) @@ -420,7 +421,12 @@ class Adamlike(Optimizer): # keep local references of mt and vt to simplify # implementing all the variations of Adam later. mt = filter_gradients(self.mt, dW, self.b1) - vt = filter_gradients(self.vt, np.square(dW), self.b2) + if self.yogi: + g2 = np.square(dW) + vt = self.vt + vt -= (1 - self.b2) * np.sign(vt - g2) * g2 + else: + vt = filter_gradients(self.vt, np.square(dW), self.b2) if self.runmax: self.vtmax[:] = np.maximum(vt, self.vtmax) @@ -455,8 +461,8 @@ class Adagrad(Adamlike): # paper: https://web.stanford.edu/~jduchi/projects/DuchiHaSi11.pdf def __init__(self, lr=0.01, eps=1e-8): - super().__init__(lr=lr, b1=0.0, b2=-1.0, - power=1/2, debias=False, runmax=False, eps=eps) + super().__init__(lr=lr, b1=0.0, b2=-1.0, power=1/2, + debias=False, runmax=False, eps=eps) @property def g(self): @@ -471,8 +477,8 @@ class RMSprop(Adamlike): # slides: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf def __init__(self, lr=0.001, mu=0.99, eps=1e-8): - super().__init__(lr=lr, b1=0.0, b2=mu, - power=1/2, debias=False, runmax=False, eps=eps) + super().__init__(lr=lr, b1=0.0, b2=mu, power=1/2, + debias=False, runmax=False, eps=eps) @property def mu(self): @@ -499,8 +505,18 @@ class Adam(Adamlike): def __init__(self, lr=0.001, b1=0.9, b2=0.999, debias=True, eps=1e-8): - super().__init__(lr=lr, b1=b1, b2=b2, - power=1/2, debias=debias, runmax=False, eps=eps) + super().__init__(lr=lr, b1=b1, b2=b2, power=1/2, + debias=debias, runmax=False, yogi=False, eps=eps) + + +class Yogi(Adamlike): + # paper: https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf + # based on Adam. this changes the filtering for vt. + + def __init__(self, lr=0.01, b1=0.9, b2=0.999, + debias=True, eps=1e-3): + super().__init__(lr=lr, b1=b1, b2=b2, power=1/2, + debias=debias, runmax=False, yogi=True, eps=eps) class AMSgrad(Adamlike): @@ -509,8 +525,8 @@ class AMSgrad(Adamlike): def __init__(self, lr=0.001, b1=0.9, b2=0.999, debias=True, eps=1e-8): - super().__init__(lr=lr, b1=b1, b2=b2, - power=1/2, debias=debias, runmax=True, eps=eps) + super().__init__(lr=lr, b1=b1, b2=b2, power=1/2, + debias=debias, runmax=True, yogi=False, eps=eps) class Padam(Adamlike): @@ -520,5 +536,5 @@ class Padam(Adamlike): def __init__(self, lr=0.1, b1=0.9, b2=0.999, power=1/8, debias=True, eps=1e-8): - super().__init__(lr=lr, b1=b1, b2=b2, - power=power, debias=debias, runmax=True, eps=eps) + super().__init__(lr=lr, b1=b1, b2=b2, power=power, + debias=debias, runmax=True, yogi=False, eps=eps)