add MSVAG optimizer
This commit is contained in:
parent
b3b82ca4f0
commit
2e80f8b1a7
2 changed files with 59 additions and 0 deletions
|
@ -587,6 +587,54 @@ class Padam(Adamlike):
|
|||
debias=debias, runmax=True, yogi=False, eps=eps)
|
||||
|
||||
|
||||
class MSVAG(Optimizer):
|
||||
# paper: https://arxiv.org/abs/1705.07774
|
||||
# this is the variance-reducing aspect isolated from the rest of Adam.
|
||||
|
||||
def __init__(self, lr=0.1, b=0.99):
|
||||
self.b = _f(b)
|
||||
super().__init__(lr=lr)
|
||||
|
||||
def reset(self):
|
||||
self.mt = None
|
||||
self.vt = None
|
||||
self.bt = self.b
|
||||
|
||||
super().reset()
|
||||
|
||||
def compute(self, dW, W):
|
||||
if self.mt is None:
|
||||
self.mt = np.zeros_like(dW)
|
||||
if self.vt is None:
|
||||
self.vt = np.zeros_like(dW)
|
||||
|
||||
mt = filter_gradients(self.mt, dW, self.b)
|
||||
vt = filter_gradients(self.vt, np.square(dW), self.b)
|
||||
|
||||
# debiasing:
|
||||
if self.bt != 1:
|
||||
mt = mt / (1 - self.bt)
|
||||
vt = vt / (1 - self.bt)
|
||||
num = (1 - self.b) * (1 + self.bt)
|
||||
den = (1 + self.b) * (1 - self.bt)
|
||||
rho = num / den
|
||||
else:
|
||||
# technically, this should be 1 / (t + 1),
|
||||
# but we don't keep track of t directly.
|
||||
rho = 1
|
||||
|
||||
if rho != 1:
|
||||
mt2 = np.square(mt)
|
||||
s = (vt - mt2) / (1 - rho)
|
||||
gamma = div0(mt2, mt2 + rho * s)
|
||||
else:
|
||||
gamma = 1
|
||||
|
||||
self.bt *= self.b
|
||||
|
||||
return -self.lr * (gamma * mt)
|
||||
|
||||
|
||||
AMSGrad = AMSgrad
|
||||
AdaDelta = Adadelta
|
||||
AdaGrad = Adagrad
|
||||
|
|
|
@ -29,6 +29,17 @@ def lower_priority():
|
|||
os.nice(1)
|
||||
|
||||
|
||||
def div0(a, b):
|
||||
"""division, whereby division by zero equals zero"""
|
||||
# http://stackoverflow.com/a/35696047
|
||||
a = np.asanyarray(a)
|
||||
b = np.asanyarray(b)
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
c = np.true_divide(a, b)
|
||||
c[~np.isfinite(c)] = 0 # -inf inf NaN
|
||||
return c
|
||||
|
||||
|
||||
def onehot(y):
|
||||
unique = np.unique(y)
|
||||
Y = np.zeros((y.shape[0], len(unique)), dtype=np.int8)
|
||||
|
|
Loading…
Reference in a new issue