From 9706aaabbbfe0075159095d387df62edf7fbf666 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Sun, 2 Jul 2017 02:55:19 +0000
Subject: [PATCH] add WIP YellowFin optimizer implementation

---
 onn.py       | 137 +++++++++++++++++++++++++++++++++++++++++++++++++++
 onn_mnist.py |  10 ++--
 2 files changed, 142 insertions(+), 5 deletions(-)

diff --git a/onn.py b/onn.py
index 8b9480f..409f503 100755
--- a/onn.py
+++ b/onn.py
@@ -154,6 +154,143 @@ class FTML(Optimizer):
         # subtract by weights to avoid having to override self.update.
         return -self.zt / self.dt - W
 
+class YellowFin(Momentum):
+    # paper: https://arxiv.org/abs/1706.03471
+    # knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
+    # author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
+
+    def __init__(self, alpha=0.1, mu=0.0, beta=0.999, curv_win_width=20):
+        self.alpha_default = _f(alpha)
+        self.mu_default = _f(mu)
+        self.beta = _f(beta)
+        self.curv_win_width = int(curv_win_width)
+
+        super().__init__(alpha=alpha, mu=mu, nesterov=False)
+
+    def reset(self):
+        super().reset()
+        self.alpha = self.alpha_default
+        self.mu = self.mu_default
+
+        self.step = 0
+        self.beta_t = self.beta
+
+        self.curv_win = np.zeros([self.curv_win_width,], dtype=np.float32)
+
+        self.h_min = None
+        self.h_max = None
+
+        self.grad_norm_squared_lpf = 0
+        self.h_min_lpf = 0
+        self.h_max_lpf = 0
+        self.grad_avg_lpf = 0
+        self.grad_avg_squared_lpf = 0
+        self.grad_norm_avg_lpf = 0
+        self.dist_to_opt_avg_lpf = 0
+        self.mu_lpf = 0
+        self.alpha_lpf = 0
+
+    def get_lr(self, mu_for_real):
+        return np.square(1 - np.sqrt(mu_for_real)) / self.h_min
+
+    def get_mu(self):
+        const_fact = np.square(self.dist_to_opt_avg) * np.square(self.h_min) / 2 / self.grad_var
+        #print('factor:', const_fact)
+        assert const_fact > -1e-7, "invalid factor"
+        coef = _f([-1, 3, -(3 + const_fact), 1])
+        roots = np.roots(coef)  # note: returns a list of np.complex64.
+
+        # filter out the correct root.
+        # we're looking for a momentum value,
+        # so it must be a real value within (0, 1).
+        # a tiny imaginary value is acceptable.
+        land = np.logical_and
+        root_idx = land(land(np.real(roots) > 0, np.real(roots) < 1), np.abs(np.imag(roots)) < 1e-5)
+        valid_roots = roots[np.where(root_idx)]
+        assert len(valid_roots) > 0, 'failed to find a valid root'
+        # there may be two valid, duplicate roots. select one.
+        real_root = np.real(valid_roots[0])
+        #print('selected root:', real_root)
+
+        dr_sqrt = np.sqrt(self.h_max / self.h_min)
+        a, b = np.square(real_root), np.square((dr_sqrt - 1) / (dr_sqrt + 1))
+        mu = max(a, b)
+        if b > a:
+            print('note: taking dr calculation')
+        #print('new momentum:', mu)
+        return _f(mu)
+
+    def compute(self, dW, W):
+        # plain momentum (pseudo-code):
+        #return -alpha * dW + mu * (W - W_old)
+        V = super().compute(dW, W)
+
+        b = self.beta
+        m1b = 1 - self.beta
+        self.debias = 1 / (1 - self.beta_t)
+        #self.debias = _1
+
+        # NOTE TO SELF: any time the reference code says "avg" they imply "lpf".
+
+        grad_squared = dW * dW
+        grad_norm_squared = np.sum(grad_squared)
+        self.grad_norm_squared_lpf = b * self.grad_norm_squared_lpf + m1b * grad_norm_squared
+        grad_norm_squared_avg = self.grad_norm_squared_lpf * self.debias
+
+        # curvature_range()
+        self.curv_win[self.step % self.curv_win_width] = grad_norm_squared
+        # remember iterations (steps) start from 0.
+        valid_window = self.curv_win[:min(self.curv_win_width, self.step + 1)]
+        h_min_t = np.min(valid_window)
+        h_max_t = np.max(valid_window)
+        #print(h_min_t, h_max_t)
+        self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
+        self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
+        self.h_min = self.h_min_lpf * self.debias
+        self.h_max = self.h_max_lpf * self.debias
+        # FIXME? the first few iterations are huuuuuuuge for regression.
+        #print(self.h_min, self.h_max)
+
+        # grad_variance()
+        self.grad_avg_lpf = b * self.grad_avg_lpf + m1b * dW
+        self.grad_avg_squared_lpf = b * self.grad_avg_squared_lpf + m1b * grad_squared
+        self.grad_avg = self.grad_avg_lpf * self.debias
+        self.grad_avg_squared = self.grad_avg_squared_lpf * self.debias
+        # FIXME: reimplement, this is weird.
+        #self._grad_avg = [self._moving_averager.average(dW)]
+        #self._grad_avg_squared = [np.square(val) for val in self._grad_avg]
+        #self._grad_var = self._grad_norm_squared_avg - np.add_n( [np.reduce_sum(val) for val in self._grad_avg_squared] )
+        # || g^2_avg - g_avg^2 ||_1
+        #self.grad_var = grad_norm_squared_avg - np.sum(self.grad_avg_squared)
+        # note: the abs probably isn't necessary here.
+        self.grad_var = np.sum(np.abs(self.grad_avg_squared - np.square(self.grad_avg)))
+        #print(self.grad_var)
+
+        # dist_to_opt()
+        grad_norm = np.sqrt(grad_norm_squared)
+        self.grad_norm_avg_lpf = b * self.grad_norm_avg_lpf + m1b * grad_norm
+        grad_norm_avg = self.grad_norm_avg_lpf * self.debias
+        # single iteration distance estimation.
+        dist_to_opt = grad_norm_avg / grad_norm_squared_avg
+        # running average of distance
+        self.dist_to_opt_avg_lpf = b * self.dist_to_opt_avg_lpf + m1b * dist_to_opt
+        self.dist_to_opt_avg = self.dist_to_opt_avg_lpf * self.debias
+
+        # update_hyper_param()
+        if self.step > 0:
+            mu_for_real = self.get_mu()
+            lr_for_real = self.get_lr(mu_for_real)
+            self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real
+            self.alpha_lpf = b * self.alpha_lpf + m1b * lr_for_real
+            self.mu = self.mu_lpf * self.debias
+            self.alpha = self.alpha_lpf * self.debias
+
+        ###
+
+        self.step += 1
+        self.beta_t *= self.beta
+        return V
+
 # Nonparametric Layers {{{1
 
 class AlphaDropout(Layer):
diff --git a/onn_mnist.py b/onn_mnist.py
index 24eae15..9e3f761 100755
--- a/onn_mnist.py
+++ b/onn_mnist.py
@@ -43,7 +43,7 @@ else:
     starts = 3
     bs = 500
 
-    learner_class = SGDR
+    learner_class = None #SGDR
     restart_decay = 0.5
 
     n_dense = 2
@@ -51,9 +51,9 @@ else:
     new_dims = (4, 12)
     activation = Relu
 
-    reg = L1L2(3.2e-5, 3.2e-4)
-    final_reg = L1L2(3.2e-5, 1e-3)
-    dropout = 0.05
+    reg = None # L1L2(3.2e-5, 3.2e-4)
+    final_reg = None # L1L2(3.2e-5, 1e-3)
+    dropout = None # 0.05
     actreg_lamb = None #1e-4
 
     load_fn = None
@@ -132,7 +132,7 @@ model = Model(x, y, unsafe=True)
 
 lr *= np.sqrt(bs)
 
-optim = Adam()
+optim = YellowFin()
 
 if learner_class == SGDR:
     learner = learner_class(optim, epochs=epochs//starts, rate=lr, restarts=starts-1, restart_decay=restart_decay,
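For reference while reviewing get_mu() above: the cubic it solves, -x^3 + 3x^2 - (3 + c)x + 1 = 0 with c = dist_to_opt_avg^2 * h_min^2 / (2 * grad_var), rearranges to (1 - x)^3 = c * x, and the momentum is mu = x^2, lower-bounded by ((sqrt(h_max / h_min) - 1) / (sqrt(h_max / h_min) + 1))^2; get_lr() then returns (1 - sqrt(mu))^2 / h_min. The snippet below is a standalone, numpy-only sketch of that tuning rule for sanity-checking outside the training loop; it is not part of the patch, and the function name tune_momentum and the example inputs are made up purely for illustration.

import numpy as np

def tune_momentum(dist_to_opt, h_min, h_max, grad_var):
    # same constant as const_fact in get_mu().
    c = dist_to_opt**2 * h_min**2 / (2 * grad_var)
    # roots of -x^3 + 3x^2 - (3 + c)x + 1 = 0, i.e. (1 - x)^3 = c * x.
    roots = np.roots([-1, 3, -(3 + c), 1])
    # for c > 0 there is exactly one real root, and it lies in (0, 1);
    # the other two roots are complex.
    real = roots[(np.abs(roots.imag) < 1e-5)
                 & (roots.real > 0) & (roots.real < 1)].real
    mu_cubic = real[0]**2
    # lower bound derived from the curvature dynamic range h_max / h_min.
    dr_sqrt = np.sqrt(h_max / h_min)
    mu = max(mu_cubic, ((dr_sqrt - 1) / (dr_sqrt + 1))**2)
    lr = (1 - np.sqrt(mu))**2 / h_min
    return mu, lr

# made-up running estimates, purely for illustration:
mu, lr = tune_momentum(dist_to_opt=1.0, h_min=0.5, h_max=50.0, grad_var=0.1)
print(mu, lr)  # mu lands in (0, 1); lr shrinks as mu approaches 1.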