reset learning rates in optimizers

This commit is contained in:
Connor Olding 2019-02-05 04:15:28 +01:00
parent bd07d983be
commit 7227559912
2 changed files with 22 additions and 1 deletions

View file

@ -34,6 +34,8 @@ class Momentum(Optimizer):
def reset(self):
self.Vprev = None
super().reset()
def compute(self, dW, W):
if self.Vprev is None:
self.Vprev = np.copy(dW)
@ -59,6 +61,8 @@ class Adadelta(Optimizer):
self.g = None
self.x = None
super().reset()
def compute(self, dW, W):
if self.g is None:
self.g = np.zeros_like(dW)
@ -88,6 +92,8 @@ class RMSpropCentered(Optimizer):
self.vt = None
self.delta = None
super().reset()
def compute(self, dW, W):
if self.g is None:
self.g = np.zeros_like(dW)
@ -136,6 +142,8 @@ class Nadam(Optimizer):
self.t = 0
self.sched = 1
super().reset()
def compute(self, dW, W):
self.t += 1
@ -184,6 +192,8 @@ class FTML(Optimizer):
self.b1_t = _1
self.b2_t = _1
super().reset()
def compute(self, dW, W):
if self.dt1 is None:
self.dt1 = np.zeros_like(dW)
@ -231,6 +241,8 @@ class MomentumClip(Optimizer):
def reset(self):
self.accum = None
super().reset()
def compute(self, dW, W):
if self.accum is None:
self.accum = np.zeros_like(dW)
@ -261,6 +273,8 @@ class AddSign(Optimizer):
def reset(self):
self.accum = None
super().reset()
def compute(self, dW, W):
if self.accum is None:
self.accum = np.zeros_like(dW)
@ -286,6 +300,8 @@ class PowerSign(Optimizer):
def reset(self):
self.accum = None
super().reset()
def compute(self, dW, W):
if self.accum is None:
self.accum = np.zeros_like(dW)
@ -326,6 +342,8 @@ class Neumann(Optimizer):
self.vt = None # weight accumulator.
self.t = 0
super().reset()
def compute(self, dW, W):
raise Exception("compute() is not available for this Optimizer.")
@ -389,6 +407,8 @@ class Adamlike(Optimizer):
self.b1_t = self.b1_t_default
self.b2_t = self.b2_t_default
super().reset()
def compute(self, dW, W):
if self.mt is None:
self.mt = np.zeros_like(dW)

View file

@ -6,10 +6,11 @@ from .float import _f
class Optimizer:
def __init__(self, lr=0.1):
self.lr = _f(lr) # learning rate
self.base_rate = self.lr
self.reset()
def reset(self):
pass
self.lr = self.base_rate
def compute(self, dW, W):
return -self.lr * dW