remove cruft from YellowFin
I might just remove YellowFin itself because it isn't working for me.
This commit is contained in:
parent
2cf38d4ece
commit
e5fd937ef6
1 changed file with 17 additions and 43 deletions
60
onn.py
60
onn.py
|
@@ -183,11 +183,11 @@ class MomentumClip(Optimizer):
|
|||
else:
|
||||
return -self.lr * self.accum
|
||||
|
||||
yfalt = True # use computations from https://gist.github.com/botev/f8b32c00eafee222e47393f7f0747666
|
||||
class YellowFin(Optimizer):
|
||||
# paper: https://arxiv.org/abs/1706.03471
|
||||
# knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
|
||||
# author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
|
||||
# code lifted: https://gist.github.com/botev/f8b32c00eafee222e47393f7f0747666
|
||||
|
||||
def __init__(self, lr=0.1, mu=0.0, beta=0.999, window_size=20,
|
||||
debias=True, clip=1.0):
|
||||
|
@@ -226,44 +226,23 @@ class YellowFin(Optimizer):
|
|||
self.mu_lpf = 0
|
||||
|
||||
def get_lr_mu(self):
|
||||
if yfalt:
|
||||
p = (np.square(self.dist_avg) * np.square(self.h_min)) / (2 * self.g_var)
|
||||
w3 = p * (np.sqrt(0.25 + p / 27.0) - 0.5)
|
||||
w = np.power(w3, 1/3)
|
||||
y = w - p / (3 * w)
|
||||
sqrt_mu1 = y + 1
|
||||
p = (np.square(self.dist_avg) * np.square(self.h_min)) / (2 * self.g_var)
|
||||
w3 = p * (np.sqrt(0.25 + p / 27.0) - 0.5)
|
||||
w = np.power(w3, 1/3)
|
||||
y = w - p / (3 * w)
|
||||
sqrt_mu1 = y + 1
|
||||
|
||||
sqrt_h_min = np.sqrt(self.h_min)
|
||||
sqrt_h_max = np.sqrt(self.h_max)
|
||||
sqrt_mu2 = (sqrt_h_max - sqrt_h_min) / (sqrt_h_max + sqrt_h_min)
|
||||
sqrt_h_min = np.sqrt(self.h_min)
|
||||
sqrt_h_max = np.sqrt(self.h_max)
|
||||
sqrt_mu2 = (sqrt_h_max - sqrt_h_min) / (sqrt_h_max + sqrt_h_min)
|
||||
|
||||
sqrt_mu = max(sqrt_mu1, sqrt_mu2)
|
||||
if sqrt_mu2 > sqrt_mu1:
|
||||
print('note: taking dr calculation. something may have exploded.')
|
||||
sqrt_mu = max(sqrt_mu1, sqrt_mu2)
|
||||
if sqrt_mu2 > sqrt_mu1:
|
||||
print('note: taking dr calculation. something may have exploded.')
|
||||
|
||||
lr = np.square(1 - sqrt_mu) / self.h_min
|
||||
mu = np.square(sqrt_mu)
|
||||
return lr, mu
|
||||
|
||||
else:
|
||||
const_fact = np.square(self.dist_avg) * np.square(self.h_min) / 2 / self.g_var
|
||||
assert const_fact > -1e-7, "invalid factor"
|
||||
coef = [-1.0, 3.0, -(3.0 + const_fact), 1.0]
|
||||
roots = np.roots(coef) # note: returns a list of np.complex64.
|
||||
|
||||
roots = roots[np.logical_and(np.real(roots) > 0, np.real(roots) < 1)]
|
||||
root = roots[np.argmin(np.imag(roots))]
|
||||
assert np.absolute(root.imag) < 1e-5
|
||||
real_root = np.real(root)
|
||||
|
||||
dr_sqrt = np.sqrt(self.h_max / self.h_min)
|
||||
a, b = np.square((dr_sqrt - 1) / (dr_sqrt + 1)), np.square(real_root)
|
||||
mu = max(a, b)
|
||||
if a > b:
|
||||
print('note: taking dr calculation. something may have exploded.')
|
||||
lr_min = np.square(1 - np.sqrt(mu)) / self.h_min
|
||||
#lr_max = np.square(1 + np.sqrt(mu)) / self.h_max
|
||||
return lr_min, mu
|
||||
lr = np.square(1 - sqrt_mu) / self.h_min
|
||||
mu = np.square(sqrt_mu)
|
||||
return lr, mu
|
||||
|
||||
def compute(self, dW, W):
|
||||
if self.accum is None:
|
||||
|
@@ -278,10 +257,6 @@ class YellowFin(Optimizer):
|
|||
#print("clipping gradients; norm: {:10.5f}".format(total_norm))
|
||||
dW *= clip_scale
|
||||
|
||||
if not yfalt:
|
||||
self.accum[:] = self.accum * self.mu + dW
|
||||
V = -self.lr * self.accum
|
||||
|
||||
#fmt = 'W std: {:10.7f}e-3, dWstd: {:10.7f}e-3, V std: {:10.7f}e-3'
|
||||
#print(fmt.format(np.std(W), np.std(dW) * 100, np.std(V) * 100))
|
||||
|
||||
|
@@ -331,9 +306,8 @@ class YellowFin(Optimizer):
|
|||
self.mu = debias * self.mu_lpf
|
||||
self.lr = debias * self.lr_lpf
|
||||
|
||||
if yfalt:
|
||||
self.accum[:] = self.accum * self.mu - self.lr * dW
|
||||
V = self.accum
|
||||
self.accum[:] = self.accum * self.mu - self.lr * dW
|
||||
V = self.accum
|
||||
|
||||
self.step += 1
|
||||
self.beta_t *= self.beta
|
||||
|
|
Loading…
Reference in a new issue