finish(?) implementing YellowFin
This commit is contained in:
parent
c41700ab8d
commit
85c9b3b5c1
165
onn.py
165
onn.py
|
@ -181,138 +181,134 @@ class MomentumClip(Optimizer):
|
||||||
else:
|
else:
|
||||||
return -self.lr * self.accum
|
return -self.lr * self.accum
|
||||||
|
|
||||||
class YellowFin(Momentum):
|
class YellowFin(Optimizer):
|
||||||
# paper: https://arxiv.org/abs/1706.03471
|
# paper: https://arxiv.org/abs/1706.03471
|
||||||
# knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
|
# knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
|
||||||
# author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
|
# author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
|
||||||
|
|
||||||
def __init__(self, lr=0.1, mu=0.0, beta=0.999, curv_win_width=20):
|
def __init__(self, lr=0.1, mu=0.0, beta=0.999, window_size=20,
|
||||||
|
debias=True, clip=1.0):
|
||||||
self.lr_default = _f(lr)
|
self.lr_default = _f(lr)
|
||||||
self.mu_default = _f(mu)
|
self.mu_default = _f(mu)
|
||||||
self.beta = _f(beta)
|
self.beta = _f(beta)
|
||||||
self.curv_win_width = int(curv_win_width)
|
self.window_size = int(window_size) # curv_win_width
|
||||||
|
self.debias_enabled = bool(debias)
|
||||||
|
self.clip = _f(clip)
|
||||||
|
|
||||||
super().__init__(lr=lr, mu=mu, nesterov=False)
|
self.mu = _f(mu) # momentum
|
||||||
|
super().__init__(lr)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
super().reset()
|
self.accum = None
|
||||||
|
|
||||||
self.lr = self.lr_default
|
self.lr = self.lr_default
|
||||||
self.mu = self.mu_default
|
self.mu = self.mu_default
|
||||||
|
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.beta_t = self.beta
|
self.beta_t = self.beta
|
||||||
|
|
||||||
self.curv_win = np.zeros([self.curv_win_width,], dtype=np.float32)
|
self.curv_win = np.zeros([self.window_size,], dtype=np.float32)
|
||||||
|
|
||||||
self.h_min = None
|
self.h_min = None
|
||||||
self.h_max = None
|
self.h_max = None
|
||||||
|
|
||||||
self.grad_norm_squared_lpf = 0
|
self.g_lpf = 0
|
||||||
|
#self.g_squared_lpf = 0
|
||||||
|
self.g_norm_squared_lpf = 0
|
||||||
|
self.g_norm_lpf = 0
|
||||||
self.h_min_lpf = 0
|
self.h_min_lpf = 0
|
||||||
self.h_max_lpf = 0
|
self.h_max_lpf = 0
|
||||||
self.grad_avg_lpf = 0
|
self.dist_lpf = 0
|
||||||
self.grad_avg_squared_lpf = 0
|
self.lr_lpf = 0
|
||||||
self.grad_norm_avg_lpf = 0
|
|
||||||
self.dist_to_opt_avg_lpf = 0
|
|
||||||
self.mu_lpf = 0
|
self.mu_lpf = 0
|
||||||
self.alpha_lpf = 0
|
|
||||||
|
|
||||||
def get_lr(self, mu_for_real):
|
def get_lr(self, mu_for_real):
|
||||||
return np.square(1 - np.sqrt(mu_for_real)) / self.h_min
|
return np.square(1 - np.sqrt(mu_for_real)) / self.h_min # lr_min
|
||||||
|
# np.square(1 + np.sqrt(mu_for_real)) / self.h_max # lr_max
|
||||||
|
|
||||||
def get_mu(self):
|
def get_mu(self):
|
||||||
const_fact = np.square(self.dist_to_opt_avg) * np.square(self.h_min) / 2 / self.grad_var
|
const_fact = np.square(self.dist_avg) * np.square(self.h_min) / 2 / self.g_var
|
||||||
#print('factor:', const_fact)
|
|
||||||
assert const_fact > -1e-7, "invalid factor"
|
assert const_fact > -1e-7, "invalid factor"
|
||||||
coef = _f([-1, 3, -(3 + const_fact), 1])
|
coef = [-1.0, 3.0, -(3.0 + const_fact), 1.0]
|
||||||
roots = np.roots(coef) # note: returns a list of np.complex64.
|
roots = np.roots(coef) # note: returns a list of np.complex64.
|
||||||
|
|
||||||
# filter out the correct root.
|
roots = roots[np.logical_and(np.real(roots) > 0, np.real(roots) < 1)]
|
||||||
# we're looking for a momentum value,
|
root = roots[np.argmin(np.imag(roots))]
|
||||||
# so it must be a real value within (0, 1).
|
assert np.absolute(root.imag) < 1e-5
|
||||||
# a tiny imaginary value is acceptable.
|
real_root = np.real(root)
|
||||||
land = np.logical_and
|
|
||||||
root_idx = land(land(np.real(roots) > 0, np.real(roots) < 1), np.abs(np.imag(roots)) < 1e-5)
|
|
||||||
valid_roots = roots[np.where(root_idx)]
|
|
||||||
assert len(valid_roots) > 0, 'failed to find a valid root'
|
|
||||||
# there may be two valid, duplicate roots. select one.
|
|
||||||
real_root = np.real(valid_roots[0])
|
|
||||||
#print('selected root:', real_root)
|
|
||||||
|
|
||||||
dr_sqrt = np.sqrt(self.h_max / self.h_min)
|
dr_sqrt = np.sqrt(self.h_max / self.h_min)
|
||||||
a, b = np.square(real_root), np.square((dr_sqrt - 1) / (dr_sqrt + 1))
|
assert self.h_max >= self.h_min
|
||||||
|
a, b = np.square((dr_sqrt - 1) / (dr_sqrt + 1)), np.square(real_root)
|
||||||
mu = max(a, b)
|
mu = max(a, b)
|
||||||
if b > a:
|
if a > b:
|
||||||
print('note: taking dr calculation')
|
print('note: taking dr calculation. something may have exploded.')
|
||||||
#print('new momentum:', mu)
|
return mu
|
||||||
return _f(mu)
|
|
||||||
|
|
||||||
def compute(self, dW, W):
|
def compute(self, dW, W):
|
||||||
# plain momentum (pseudo-code):
|
if self.accum is None:
|
||||||
#return -alpha * dW + mu * (W - W_old)
|
self.accum = np.zeros_like(dW)
|
||||||
V = super().compute(dW, W)
|
|
||||||
|
# TODO: prevent allocations everywhere by using [:].
|
||||||
|
# assuming that really works. i haven't actually checked.
|
||||||
|
|
||||||
|
total_norm = np.linalg.norm(dW)
|
||||||
|
clip_scale = self.clip / (total_norm + 1e-6)
|
||||||
|
if clip_scale < 1:
|
||||||
|
#print("clipping gradients; norm: {:10.7f}".format(total_norm))
|
||||||
|
dW *= clip_scale
|
||||||
|
|
||||||
|
self.accum[:] = self.accum * self.mu + dW
|
||||||
|
V = -self.lr * self.accum
|
||||||
|
|
||||||
|
#fmt = 'W std: {:10.7f}e-3, dWstd: {:10.7f}e-3, V std: {:10.7f}e-3'
|
||||||
|
#print(fmt.format(np.std(W), np.std(dW) * 100, np.std(V) * 100))
|
||||||
|
|
||||||
b = self.beta
|
b = self.beta
|
||||||
m1b = 1 - self.beta
|
m1b = 1 - self.beta
|
||||||
self.debias = 1 / (1 - self.beta_t)
|
debias = 1 / (1 - self.beta_t) if self.debias_enabled else 1
|
||||||
#self.debias = _1
|
|
||||||
|
|
||||||
# NOTE TO SELF: any time the reference code says "avg" they imply "lpf".
|
g = dW
|
||||||
|
g_squared = np.square(g)
|
||||||
|
g_norm_squared = np.sum(g_squared)
|
||||||
|
g_norm = np.sqrt(g_norm_squared)
|
||||||
|
|
||||||
grad_squared = dW * dW
|
self.curv_win[self.step % self.window_size] = g_norm_squared
|
||||||
grad_norm_squared = np.sum(grad_squared)
|
valid_window = self.curv_win[:min(self.window_size, self.step + 1)]
|
||||||
self.grad_norm_squared_lpf = b * self.grad_norm_squared_lpf + m1b * grad_norm_squared
|
|
||||||
grad_norm_squared_avg = self.grad_norm_squared_lpf * self.debias
|
|
||||||
|
|
||||||
# curvature_range()
|
|
||||||
self.curv_win[self.step % self.curv_win_width] = grad_norm_squared
|
|
||||||
# remember iterations (steps) start from 0.
|
|
||||||
valid_window = self.curv_win[:min(self.curv_win_width, self.step + 1)]
|
|
||||||
h_min_t = np.min(valid_window)
|
h_min_t = np.min(valid_window)
|
||||||
h_max_t = np.max(valid_window)
|
h_max_t = np.max(valid_window)
|
||||||
#print(h_min_t, h_max_t)
|
|
||||||
|
self.g_lpf = b * self.g_lpf + m1b * g
|
||||||
|
#self.g_squared_lpf = b * self.g_squared_lpf + m1b * g_squared
|
||||||
|
self.g_norm_squared_lpf = b * self.g_norm_squared_lpf + m1b * g_norm_squared
|
||||||
|
self.g_norm_lpf = b * self.g_norm_lpf + m1b * g_norm
|
||||||
self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
|
self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
|
||||||
self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
|
self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
|
||||||
self.h_min = self.h_min_lpf * self.debias
|
|
||||||
self.h_max = self.h_max_lpf * self.debias
|
|
||||||
# FIXME? the first few iterations are huuuuuuuge for regression.
|
|
||||||
#print(self.h_min, self.h_max)
|
|
||||||
|
|
||||||
# grad_variance()
|
g_avg = debias * self.g_lpf
|
||||||
self.grad_avg_lpf = b * self.grad_avg_lpf + m1b * dW
|
#g_squared_avg = debias * self.g_squared_lpf
|
||||||
self.grad_avg_squared_lpf = b * self.grad_avg_squared_lpf + m1b * grad_squared
|
g_norm_squared_avg = debias * self.g_norm_squared_lpf
|
||||||
self.grad_avg = self.grad_avg_lpf * self.debias
|
g_norm_avg = debias * self.g_norm_lpf
|
||||||
self.grad_avg_squared = self.grad_avg_squared_lpf * self.debias
|
self.h_min = debias * self.h_min_lpf
|
||||||
# FIXME: reimplement, this is weird.
|
self.h_max = debias * self.h_max_lpf
|
||||||
#self._grad_avg = [self._moving_averager.average(dW)]
|
|
||||||
#self._grad_avg_squared = [np.square(val) for val in self._grad_avg]
|
|
||||||
#self._grad_var = self._grad_norm_squared_avg - np.add_n( [np.reduce_sum(val) for val in self._grad_avg_squared] )
|
|
||||||
# || g^2_avg - g_avg^2 ||_1
|
|
||||||
#self.grad_var = grad_norm_squared_avg - np.sum(self.grad_avg_squared)
|
|
||||||
# note: the abs probably isn't necessary here.
|
|
||||||
self.grad_var = np.sum(np.abs(self.grad_avg_squared - np.square(self.grad_avg)))
|
|
||||||
#print(self.grad_var)
|
|
||||||
|
|
||||||
# dist_to_opt()
|
dist = g_norm_avg / g_norm_squared_avg
|
||||||
grad_norm = np.sqrt(grad_norm_squared)
|
|
||||||
self.grad_norm_avg_lpf = b * self.grad_norm_avg_lpf + m1b * grad_norm
|
self.dist_lpf = b * self.dist_lpf + m1b * dist
|
||||||
grad_norm_avg = self.grad_norm_avg_lpf * self.debias
|
|
||||||
# single iteration distance estimation.
|
self.dist_avg = debias * self.dist_lpf
|
||||||
dist_to_opt = grad_norm_avg / grad_norm_squared_avg
|
|
||||||
# running average of distance
|
self.g_var = g_norm_squared_avg - np.sum(np.square(g_avg))
|
||||||
self.dist_to_opt_avg_lpf = b * self.dist_to_opt_avg_lpf + m1b * dist_to_opt
|
# alternatively:
|
||||||
self.dist_to_opt_avg = self.dist_to_opt_avg_lpf * self.debias
|
#self.g_var = np.sum(np.abs(g_squared_avg - np.square(g_avg)))
|
||||||
|
|
||||||
# update_hyper_param()
|
|
||||||
if self.step > 0:
|
if self.step > 0:
|
||||||
mu_for_real = self.get_mu()
|
mu_for_real = self.get_mu()
|
||||||
lr_for_real = self.get_lr(mu_for_real)
|
lr_for_real = self.get_lr(mu_for_real)
|
||||||
self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real
|
self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real
|
||||||
self.alpha_lpf = b * self.alpha_lpf + m1b * lr_for_real
|
self.lr_lpf = b * self.lr_lpf + m1b * lr_for_real
|
||||||
self.mu = self.mu_lpf * self.debias
|
self.mu = debias * self.mu_lpf
|
||||||
self.alpha = self.alpha_lpf * self.debias
|
self.lr = debias * self.lr_lpf
|
||||||
|
|
||||||
###
|
|
||||||
|
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.beta_t *= self.beta
|
self.beta_t *= self.beta
|
||||||
|
@ -854,6 +850,13 @@ def optim_from_config(config):
|
||||||
b1 = np.exp(-1/d1)
|
b1 = np.exp(-1/d1)
|
||||||
b2 = np.exp(-1/d2)
|
b2 = np.exp(-1/d2)
|
||||||
optim = FTML(b1=b1, b2=b2)
|
optim = FTML(b1=b1, b2=b2)
|
||||||
|
elif config.optim == 'yf':
|
||||||
|
d1 = config.optim_decay1 if 'optim_decay1' in config else 999.5
|
||||||
|
d2 = config.optim_decay2 if 'optim_decay2' in config else 999.5
|
||||||
|
if d1 != d2:
|
||||||
|
raise Exception("yellowfin only uses one decay term.")
|
||||||
|
beta = np.exp(-1/d1)
|
||||||
|
optim = YellowFin(beta=beta)
|
||||||
elif config.optim in ('rms', 'rmsprop'):
|
elif config.optim in ('rms', 'rmsprop'):
|
||||||
d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5
|
d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5
|
||||||
mu = np.exp(-1/d2)
|
mu = np.exp(-1/d2)
|
||||||
|
|
15
onn_mnist.py
15
onn_mnist.py
|
@ -46,19 +46,19 @@ else:
|
||||||
learner_class = None #SGDR
|
learner_class = None #SGDR
|
||||||
restart_decay = 0.5
|
restart_decay = 0.5
|
||||||
|
|
||||||
n_dense = 2
|
n_dense = 1
|
||||||
n_denses = 0
|
n_denses = 0
|
||||||
new_dims = (4, 12)
|
new_dims = (4, 12)
|
||||||
activation = Relu
|
activation = Relu # GeluApprox
|
||||||
|
|
||||||
reg = None # L1L2(3.2e-5, 3.2e-4)
|
reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
|
||||||
final_reg = None # L1L2(3.2e-5, 1e-3)
|
final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
|
||||||
dropout = None # 0.05
|
dropout = None # 0.05
|
||||||
actreg_lamb = None #1e-4
|
actreg_lamb = None #1e-4
|
||||||
|
|
||||||
load_fn = None
|
load_fn = None
|
||||||
save_fn = 'mnist.h5'
|
save_fn = 'mnist3.h5'
|
||||||
log_fn = 'mnist_losses.npz'
|
log_fn = 'mnist_losses3.npz'
|
||||||
|
|
||||||
fn = 'mnist.npz'
|
fn = 'mnist.npz'
|
||||||
mnist_dim = 28
|
mnist_dim = 28
|
||||||
|
@ -132,7 +132,8 @@ model = Model(x, y, unsafe=True)
|
||||||
|
|
||||||
lr *= np.sqrt(bs)
|
lr *= np.sqrt(bs)
|
||||||
|
|
||||||
optim = YellowFin()
|
optim = YellowFin() #beta=np.exp(-1/240)
|
||||||
|
#optim = MomentumClip(0.8, 0.8)
|
||||||
if learner_class == SGDR:
|
if learner_class == SGDR:
|
||||||
learner = learner_class(optim, epochs=epochs//starts, rate=lr,
|
learner = learner_class(optim, epochs=epochs//starts, rate=lr,
|
||||||
restarts=starts-1, restart_decay=restart_decay,
|
restarts=starts-1, restart_decay=restart_decay,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user