add WIP YellowFin optimizer implementation
This commit is contained in:
parent
d8bf6d1c5b
commit
9706aaabbb
2 changed files with 142 additions and 5 deletions
137
onn.py
137
onn.py
|
@ -154,6 +154,143 @@ class FTML(Optimizer):
|
|||
# subtract by weights to avoid having to override self.update.
|
||||
return -self.zt / self.dt - W
|
||||
|
||||
class YellowFin(Momentum):
|
||||
# paper: https://arxiv.org/abs/1706.03471
|
||||
# knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
|
||||
# author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
|
||||
|
||||
def __init__(self, alpha=0.1, mu=0.0, beta=0.999, curv_win_width=20):
|
||||
self.alpha_default = _f(alpha)
|
||||
self.mu_default = _f(mu)
|
||||
self.beta = _f(beta)
|
||||
self.curv_win_width = int(curv_win_width)
|
||||
|
||||
super().__init__(alpha=alpha, mu=mu, nesterov=False)
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self.alpha = self.alpha_default
|
||||
self.mu = self.mu_default
|
||||
|
||||
self.step = 0
|
||||
self.beta_t = self.beta
|
||||
|
||||
self.curv_win = np.zeros([self.curv_win_width,], dtype=np.float32)
|
||||
|
||||
self.h_min = None
|
||||
self.h_max = None
|
||||
|
||||
self.grad_norm_squared_lpf = 0
|
||||
self.h_min_lpf = 0
|
||||
self.h_max_lpf = 0
|
||||
self.grad_avg_lpf = 0
|
||||
self.grad_avg_squared_lpf = 0
|
||||
self.grad_norm_avg_lpf = 0
|
||||
self.dist_to_opt_avg_lpf = 0
|
||||
self.mu_lpf = 0
|
||||
self.alpha_lpf = 0
|
||||
|
||||
def get_lr(self, mu_for_real):
|
||||
return np.square(1 - np.sqrt(mu_for_real)) / self.h_min
|
||||
|
||||
def get_mu(self):
|
||||
const_fact = np.square(self.dist_to_opt_avg) * np.square(self.h_min) / 2 / self.grad_var
|
||||
#print('factor:', const_fact)
|
||||
assert const_fact > -1e-7, "invalid factor"
|
||||
coef = _f([-1, 3, -(3 + const_fact), 1])
|
||||
roots = np.roots(coef) # note: returns a list of np.complex64.
|
||||
|
||||
# filter out the correct root.
|
||||
# we're looking for a momentum value,
|
||||
# so it must be a real value within (0, 1).
|
||||
# a tiny imaginary value is acceptable.
|
||||
land = np.logical_and
|
||||
root_idx = land(land(np.real(roots) > 0, np.real(roots) < 1), np.abs(np.imag(roots)) < 1e-5)
|
||||
valid_roots = roots[np.where(root_idx)]
|
||||
assert len(valid_roots) > 0, 'failed to find a valid root'
|
||||
# there may be two valid, duplicate roots. select one.
|
||||
real_root = np.real(valid_roots[0])
|
||||
#print('selected root:', real_root)
|
||||
|
||||
dr_sqrt = np.sqrt(self.h_max / self.h_min)
|
||||
a, b = np.square(real_root), np.square((dr_sqrt - 1) / (dr_sqrt + 1))
|
||||
mu = max(a, b)
|
||||
if b > a:
|
||||
print('note: taking dr calculation')
|
||||
#print('new momentum:', mu)
|
||||
return _f(mu)
|
||||
|
||||
def compute(self, dW, W):
|
||||
# plain momentum (pseudo-code):
|
||||
#return -alpha * dW + mu * (W - W_old)
|
||||
V = super().compute(dW, W)
|
||||
|
||||
b = self.beta
|
||||
m1b = 1 - self.beta
|
||||
self.debias = 1 / (1 - self.beta_t)
|
||||
#self.debias = _1
|
||||
|
||||
# NOTE TO SELF: any time the reference code says "avg" they imply "lpf".
|
||||
|
||||
grad_squared = dW * dW
|
||||
grad_norm_squared = np.sum(grad_squared)
|
||||
self.grad_norm_squared_lpf = b * self.grad_norm_squared_lpf + m1b * grad_norm_squared
|
||||
grad_norm_squared_avg = self.grad_norm_squared_lpf * self.debias
|
||||
|
||||
# curvature_range()
|
||||
self.curv_win[self.step % self.curv_win_width] = grad_norm_squared
|
||||
# remember iterations (steps) start from 0.
|
||||
valid_window = self.curv_win[:min(self.curv_win_width, self.step + 1)]
|
||||
h_min_t = np.min(valid_window)
|
||||
h_max_t = np.max(valid_window)
|
||||
#print(h_min_t, h_max_t)
|
||||
self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
|
||||
self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
|
||||
self.h_min = self.h_min_lpf * self.debias
|
||||
self.h_max = self.h_max_lpf * self.debias
|
||||
# FIXME? the first few iterations are huuuuuuuge for regression.
|
||||
#print(self.h_min, self.h_max)
|
||||
|
||||
# grad_variance()
|
||||
self.grad_avg_lpf = b * self.grad_avg_lpf + m1b * dW
|
||||
self.grad_avg_squared_lpf = b * self.grad_avg_squared_lpf + m1b * grad_squared
|
||||
self.grad_avg = self.grad_avg_lpf * self.debias
|
||||
self.grad_avg_squared = self.grad_avg_squared_lpf * self.debias
|
||||
# FIXME: reimplement, this is weird.
|
||||
#self._grad_avg = [self._moving_averager.average(dW)]
|
||||
#self._grad_avg_squared = [np.square(val) for val in self._grad_avg]
|
||||
#self._grad_var = self._grad_norm_squared_avg - np.add_n( [np.reduce_sum(val) for val in self._grad_avg_squared] )
|
||||
# || g^2_avg - g_avg^2 ||_1
|
||||
#self.grad_var = grad_norm_squared_avg - np.sum(self.grad_avg_squared)
|
||||
# note: the abs probably isn't necessary here.
|
||||
self.grad_var = np.sum(np.abs(self.grad_avg_squared - np.square(self.grad_avg)))
|
||||
#print(self.grad_var)
|
||||
|
||||
# dist_to_opt()
|
||||
grad_norm = np.sqrt(grad_norm_squared)
|
||||
self.grad_norm_avg_lpf = b * self.grad_norm_avg_lpf + m1b * grad_norm
|
||||
grad_norm_avg = self.grad_norm_avg_lpf * self.debias
|
||||
# single iteration distance estimation.
|
||||
dist_to_opt = grad_norm_avg / grad_norm_squared_avg
|
||||
# running average of distance
|
||||
self.dist_to_opt_avg_lpf = b * self.dist_to_opt_avg_lpf + m1b * dist_to_opt
|
||||
self.dist_to_opt_avg = self.dist_to_opt_avg_lpf * self.debias
|
||||
|
||||
# update_hyper_param()
|
||||
if self.step > 0:
|
||||
mu_for_real = self.get_mu()
|
||||
lr_for_real = self.get_lr(mu_for_real)
|
||||
self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real
|
||||
self.alpha_lpf = b * self.alpha_lpf + m1b * lr_for_real
|
||||
self.mu = self.mu_lpf * self.debias
|
||||
self.alpha = self.alpha_lpf * self.debias
|
||||
|
||||
###
|
||||
|
||||
self.step += 1
|
||||
self.beta_t *= self.beta
|
||||
return V
|
||||
|
||||
# Nonparametric Layers {{{1
|
||||
|
||||
class AlphaDropout(Layer):
|
||||
|
|
10
onn_mnist.py
10
onn_mnist.py
|
@ -43,7 +43,7 @@ else:
|
|||
starts = 3
|
||||
bs = 500
|
||||
|
||||
learner_class = SGDR
|
||||
learner_class = None #SGDR
|
||||
restart_decay = 0.5
|
||||
|
||||
n_dense = 2
|
||||
|
@ -51,9 +51,9 @@ else:
|
|||
new_dims = (4, 12)
|
||||
activation = Relu
|
||||
|
||||
reg = L1L2(3.2e-5, 3.2e-4)
|
||||
final_reg = L1L2(3.2e-5, 1e-3)
|
||||
dropout = 0.05
|
||||
reg = None # L1L2(3.2e-5, 3.2e-4)
|
||||
final_reg = None # L1L2(3.2e-5, 1e-3)
|
||||
dropout = None # 0.05
|
||||
actreg_lamb = None #1e-4
|
||||
|
||||
load_fn = None
|
||||
|
@ -132,7 +132,7 @@ model = Model(x, y, unsafe=True)
|
|||
|
||||
lr *= np.sqrt(bs)
|
||||
|
||||
optim = Adam()
|
||||
optim = YellowFin()
|
||||
if learner_class == SGDR:
|
||||
learner = learner_class(optim, epochs=epochs//starts, rate=lr,
|
||||
restarts=starts-1, restart_decay=restart_decay,
|
||||
|
|
Loading…
Reference in a new issue