finish(?) implementing YellowFin

This commit is contained in:
Connor Olding 2017-07-03 09:47:31 +00:00
parent c41700ab8d
commit 85c9b3b5c1
2 changed files with 94 additions and 90 deletions

169
onn.py
View File

@ -181,138 +181,134 @@ class MomentumClip(Optimizer):
else:
return -self.lr * self.accum
class YellowFin(Momentum):
class YellowFin(Optimizer):
# paper: https://arxiv.org/abs/1706.03471
# knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
# author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
def __init__(self, lr=0.1, mu=0.0, beta=0.999, curv_win_width=20):
def __init__(self, lr=0.1, mu=0.0, beta=0.999, window_size=20,
debias=True, clip=1.0):
self.lr_default = _f(lr)
self.mu_default = _f(mu)
self.beta = _f(beta)
self.curv_win_width = int(curv_win_width)
self.window_size = int(window_size) # curv_win_width
self.debias_enabled = bool(debias)
self.clip = _f(clip)
super().__init__(lr=lr, mu=mu, nesterov=False)
self.mu = _f(mu) # momentum
super().__init__(lr)
def reset(self):
super().reset()
self.accum = None
self.lr = self.lr_default
self.mu = self.mu_default
self.step = 0
self.beta_t = self.beta
self.curv_win = np.zeros([self.curv_win_width,], dtype=np.float32)
self.curv_win = np.zeros([self.window_size,], dtype=np.float32)
self.h_min = None
self.h_max = None
self.grad_norm_squared_lpf = 0
self.g_lpf = 0
#self.g_squared_lpf = 0
self.g_norm_squared_lpf = 0
self.g_norm_lpf = 0
self.h_min_lpf = 0
self.h_max_lpf = 0
self.grad_avg_lpf = 0
self.grad_avg_squared_lpf = 0
self.grad_norm_avg_lpf = 0
self.dist_to_opt_avg_lpf = 0
self.dist_lpf = 0
self.lr_lpf = 0
self.mu_lpf = 0
self.alpha_lpf = 0
def get_lr(self, mu_for_real):
return np.square(1 - np.sqrt(mu_for_real)) / self.h_min
return np.square(1 - np.sqrt(mu_for_real)) / self.h_min # lr_min
# np.square(1 + np.sqrt(mu_for_real)) / self.h_max # lr_max
def get_mu(self):
const_fact = np.square(self.dist_to_opt_avg) * np.square(self.h_min) / 2 / self.grad_var
#print('factor:', const_fact)
const_fact = np.square(self.dist_avg) * np.square(self.h_min) / 2 / self.g_var
assert const_fact > -1e-7, "invalid factor"
coef = _f([-1, 3, -(3 + const_fact), 1])
coef = [-1.0, 3.0, -(3.0 + const_fact), 1.0]
roots = np.roots(coef) # note: returns a list of np.complex64.
# filter out the correct root.
# we're looking for a momentum value,
# so it must be a real value within (0, 1).
# a tiny imaginary value is acceptable.
land = np.logical_and
root_idx = land(land(np.real(roots) > 0, np.real(roots) < 1), np.abs(np.imag(roots)) < 1e-5)
valid_roots = roots[np.where(root_idx)]
assert len(valid_roots) > 0, 'failed to find a valid root'
# there may be two valid, duplicate roots. select one.
real_root = np.real(valid_roots[0])
#print('selected root:', real_root)
roots = roots[np.logical_and(np.real(roots) > 0, np.real(roots) < 1)]
root = roots[np.argmin(np.imag(roots))]
assert np.absolute(root.imag) < 1e-5
real_root = np.real(root)
dr_sqrt = np.sqrt(self.h_max / self.h_min)
a, b = np.square(real_root), np.square((dr_sqrt - 1) / (dr_sqrt + 1))
assert self.h_max >= self.h_min
a, b = np.square((dr_sqrt - 1) / (dr_sqrt + 1)), np.square(real_root)
mu = max(a, b)
if b > a:
print('note: taking dr calculation')
#print('new momentum:', mu)
return _f(mu)
if a > b:
print('note: taking dr calculation. something may have exploded.')
return mu
def compute(self, dW, W):
# plain momentum (pseudo-code):
#return -alpha * dW + mu * (W - W_old)
V = super().compute(dW, W)
if self.accum is None:
self.accum = np.zeros_like(dW)
# TODO: prevent allocations everywhere by using [:].
# assuming that really works. i haven't actually checked.
total_norm = np.linalg.norm(dW)
clip_scale = self.clip / (total_norm + 1e-6)
if clip_scale < 1:
#print("clipping gradients; norm: {:10.7f}".format(total_norm))
dW *= clip_scale
self.accum[:] = self.accum * self.mu + dW
V = -self.lr * self.accum
#fmt = 'W std: {:10.7f}e-3, dWstd: {:10.7f}e-3, V std: {:10.7f}e-3'
#print(fmt.format(np.std(W), np.std(dW) * 100, np.std(V) * 100))
b = self.beta
m1b = 1 - self.beta
self.debias = 1 / (1 - self.beta_t)
#self.debias = _1
debias = 1 / (1 - self.beta_t) if self.debias_enabled else 1
# NOTE TO SELF: any time the reference code says "avg" they imply "lpf".
g = dW
g_squared = np.square(g)
g_norm_squared = np.sum(g_squared)
g_norm = np.sqrt(g_norm_squared)
grad_squared = dW * dW
grad_norm_squared = np.sum(grad_squared)
self.grad_norm_squared_lpf = b * self.grad_norm_squared_lpf + m1b * grad_norm_squared
grad_norm_squared_avg = self.grad_norm_squared_lpf * self.debias
# curvature_range()
self.curv_win[self.step % self.curv_win_width] = grad_norm_squared
# remember iterations (steps) start from 0.
valid_window = self.curv_win[:min(self.curv_win_width, self.step + 1)]
self.curv_win[self.step % self.window_size] = g_norm_squared
valid_window = self.curv_win[:min(self.window_size, self.step + 1)]
h_min_t = np.min(valid_window)
h_max_t = np.max(valid_window)
#print(h_min_t, h_max_t)
self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
self.h_min = self.h_min_lpf * self.debias
self.h_max = self.h_max_lpf * self.debias
# FIXME? the first few iterations are huuuuuuuge for regression.
#print(self.h_min, self.h_max)
# grad_variance()
self.grad_avg_lpf = b * self.grad_avg_lpf + m1b * dW
self.grad_avg_squared_lpf = b * self.grad_avg_squared_lpf + m1b * grad_squared
self.grad_avg = self.grad_avg_lpf * self.debias
self.grad_avg_squared = self.grad_avg_squared_lpf * self.debias
# FIXME: reimplement, this is weird.
#self._grad_avg = [self._moving_averager.average(dW)]
#self._grad_avg_squared = [np.square(val) for val in self._grad_avg]
#self._grad_var = self._grad_norm_squared_avg - np.add_n( [np.reduce_sum(val) for val in self._grad_avg_squared] )
# || g^2_avg - g_avg^2 ||_1
#self.grad_var = grad_norm_squared_avg - np.sum(self.grad_avg_squared)
# note: the abs probably isn't necessary here.
self.grad_var = np.sum(np.abs(self.grad_avg_squared - np.square(self.grad_avg)))
#print(self.grad_var)
self.g_lpf = b * self.g_lpf + m1b * g
#self.g_squared_lpf = b * self.g_squared_lpf + m1b * g_squared
self.g_norm_squared_lpf = b * self.g_norm_squared_lpf + m1b * g_norm_squared
self.g_norm_lpf = b * self.g_norm_lpf + m1b * g_norm
self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
# dist_to_opt()
grad_norm = np.sqrt(grad_norm_squared)
self.grad_norm_avg_lpf = b * self.grad_norm_avg_lpf + m1b * grad_norm
grad_norm_avg = self.grad_norm_avg_lpf * self.debias
# single iteration distance estimation.
dist_to_opt = grad_norm_avg / grad_norm_squared_avg
# running average of distance
self.dist_to_opt_avg_lpf = b * self.dist_to_opt_avg_lpf + m1b * dist_to_opt
self.dist_to_opt_avg = self.dist_to_opt_avg_lpf * self.debias
g_avg = debias * self.g_lpf
#g_squared_avg = debias * self.g_squared_lpf
g_norm_squared_avg = debias * self.g_norm_squared_lpf
g_norm_avg = debias * self.g_norm_lpf
self.h_min = debias * self.h_min_lpf
self.h_max = debias * self.h_max_lpf
dist = g_norm_avg / g_norm_squared_avg
self.dist_lpf = b * self.dist_lpf + m1b * dist
self.dist_avg = debias * self.dist_lpf
self.g_var = g_norm_squared_avg - np.sum(np.square(g_avg))
# alternatively:
#self.g_var = np.sum(np.abs(g_squared_avg - np.square(g_avg)))
# update_hyper_param()
if self.step > 0:
mu_for_real = self.get_mu()
lr_for_real = self.get_lr(mu_for_real)
self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real
self.alpha_lpf = b * self.alpha_lpf + m1b * lr_for_real
self.mu = self.mu_lpf * self.debias
self.alpha = self.alpha_lpf * self.debias
###
self.lr_lpf = b * self.lr_lpf + m1b * lr_for_real
self.mu = debias * self.mu_lpf
self.lr = debias * self.lr_lpf
self.step += 1
self.beta_t *= self.beta
@ -854,6 +850,13 @@ def optim_from_config(config):
b1 = np.exp(-1/d1)
b2 = np.exp(-1/d2)
optim = FTML(b1=b1, b2=b2)
elif config.optim == 'yf':
d1 = config.optim_decay1 if 'optim_decay1' in config else 999.5
d2 = config.optim_decay2 if 'optim_decay2' in config else 999.5
if d1 != d2:
raise Exception("yellowfin only uses one decay term.")
beta = np.exp(-1/d1)
optim = YellowFin(beta=beta)
elif config.optim in ('rms', 'rmsprop'):
d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5
mu = np.exp(-1/d2)

View File

@ -46,19 +46,19 @@ else:
learner_class = None #SGDR
restart_decay = 0.5
n_dense = 2
n_dense = 1
n_denses = 0
new_dims = (4, 12)
activation = Relu
activation = Relu # GeluApprox
reg = None # L1L2(3.2e-5, 3.2e-4)
final_reg = None # L1L2(3.2e-5, 1e-3)
reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
dropout = None # 0.05
actreg_lamb = None #1e-4
load_fn = None
save_fn = 'mnist.h5'
log_fn = 'mnist_losses.npz'
save_fn = 'mnist3.h5'
log_fn = 'mnist_losses3.npz'
fn = 'mnist.npz'
mnist_dim = 28
@ -132,7 +132,8 @@ model = Model(x, y, unsafe=True)
lr *= np.sqrt(bs)
optim = YellowFin()
optim = YellowFin() #beta=np.exp(-1/240)
#optim = MomentumClip(0.8, 0.8)
if learner_class == SGDR:
learner = learner_class(optim, epochs=epochs//starts, rate=lr,
restarts=starts-1, restart_decay=restart_decay,