From 85c9b3b5c1a6a8a612a66af36077e11a2ba33c4f Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 3 Jul 2017 09:47:31 +0000 Subject: [PATCH] finish(?) implementing YellowFin --- onn.py | 169 ++++++++++++++++++++++++++------------------------- onn_mnist.py | 15 ++--- 2 files changed, 94 insertions(+), 90 deletions(-) diff --git a/onn.py b/onn.py index 1ccec65..f6a2bc0 100755 --- a/onn.py +++ b/onn.py @@ -181,138 +181,134 @@ class MomentumClip(Optimizer): else: return -self.lr * self.accum -class YellowFin(Momentum): +class YellowFin(Optimizer): # paper: https://arxiv.org/abs/1706.03471 # knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/ # author's implementation: https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py - def __init__(self, lr=0.1, mu=0.0, beta=0.999, curv_win_width=20): + def __init__(self, lr=0.1, mu=0.0, beta=0.999, window_size=20, + debias=True, clip=1.0): self.lr_default = _f(lr) self.mu_default = _f(mu) self.beta = _f(beta) - self.curv_win_width = int(curv_win_width) + self.window_size = int(window_size) # curv_win_width + self.debias_enabled = bool(debias) + self.clip = _f(clip) - super().__init__(lr=lr, mu=mu, nesterov=False) + self.mu = _f(mu) # momentum + super().__init__(lr) def reset(self): - super().reset() + self.accum = None + self.lr = self.lr_default self.mu = self.mu_default self.step = 0 self.beta_t = self.beta - self.curv_win = np.zeros([self.curv_win_width,], dtype=np.float32) + self.curv_win = np.zeros([self.window_size,], dtype=np.float32) self.h_min = None self.h_max = None - self.grad_norm_squared_lpf = 0 + self.g_lpf = 0 + #self.g_squared_lpf = 0 + self.g_norm_squared_lpf = 0 + self.g_norm_lpf = 0 self.h_min_lpf = 0 self.h_max_lpf = 0 - self.grad_avg_lpf = 0 - self.grad_avg_squared_lpf = 0 - self.grad_norm_avg_lpf = 0 - self.dist_to_opt_avg_lpf = 0 + self.dist_lpf = 0 + self.lr_lpf = 0 self.mu_lpf = 0 - self.alpha_lpf = 0 def get_lr(self, mu_for_real): - return np.square(1 - np.sqrt(mu_for_real)) / self.h_min + return np.square(1 - np.sqrt(mu_for_real)) / self.h_min # lr_min + # np.square(1 + np.sqrt(mu_for_real)) / self.h_max # lr_max def get_mu(self): - const_fact = np.square(self.dist_to_opt_avg) * np.square(self.h_min) / 2 / self.grad_var - #print('factor:', const_fact) + const_fact = np.square(self.dist_avg) * np.square(self.h_min) / 2 / self.g_var assert const_fact > -1e-7, "invalid factor" - coef = _f([-1, 3, -(3 + const_fact), 1]) + coef = [-1.0, 3.0, -(3.0 + const_fact), 1.0] roots = np.roots(coef) # note: returns a list of np.complex64. - # filter out the correct root. - # we're looking for a momentum value, - # so it must be a real value within (0, 1). - # a tiny imaginary value is acceptable. - land = np.logical_and - root_idx = land(land(np.real(roots) > 0, np.real(roots) < 1), np.abs(np.imag(roots)) < 1e-5) - valid_roots = roots[np.where(root_idx)] - assert len(valid_roots) > 0, 'failed to find a valid root' - # there may be two valid, duplicate roots. select one. - real_root = np.real(valid_roots[0]) - #print('selected root:', real_root) + roots = roots[np.logical_and(np.real(roots) > 0, np.real(roots) < 1)] + root = roots[np.argmin(np.imag(roots))] + assert np.absolute(root.imag) < 1e-5 + real_root = np.real(root) dr_sqrt = np.sqrt(self.h_max / self.h_min) - a, b = np.square(real_root), np.square((dr_sqrt - 1) / (dr_sqrt + 1)) + assert self.h_max >= self.h_min + a, b = np.square((dr_sqrt - 1) / (dr_sqrt + 1)), np.square(real_root) mu = max(a, b) - if b > a: - print('note: taking dr calculation') - #print('new momentum:', mu) - return _f(mu) + if a > b: + print('note: taking dr calculation. something may have exploded.') + return mu def compute(self, dW, W): - # plain momentum (pseudo-code): - #return -alpha * dW + mu * (W - W_old) - V = super().compute(dW, W) + if self.accum is None: + self.accum = np.zeros_like(dW) + + # TODO: prevent allocations everywhere by using [:]. + # assuming that really works. i haven't actually checked. + + total_norm = np.linalg.norm(dW) + clip_scale = self.clip / (total_norm + 1e-6) + if clip_scale < 1: + #print("clipping gradients; norm: {:10.7f}".format(total_norm)) + dW *= clip_scale + + self.accum[:] = self.accum * self.mu + dW + V = -self.lr * self.accum + + #fmt = 'W std: {:10.7f}e-3, dWstd: {:10.7f}e-3, V std: {:10.7f}e-3' + #print(fmt.format(np.std(W), np.std(dW) * 100, np.std(V) * 100)) b = self.beta m1b = 1 - self.beta - self.debias = 1 / (1 - self.beta_t) - #self.debias = _1 + debias = 1 / (1 - self.beta_t) if self.debias_enabled else 1 - # NOTE TO SELF: any time the reference code says "avg" they imply "lpf". + g = dW + g_squared = np.square(g) + g_norm_squared = np.sum(g_squared) + g_norm = np.sqrt(g_norm_squared) - grad_squared = dW * dW - grad_norm_squared = np.sum(grad_squared) - self.grad_norm_squared_lpf = b * self.grad_norm_squared_lpf + m1b * grad_norm_squared - grad_norm_squared_avg = self.grad_norm_squared_lpf * self.debias - - # curvature_range() - self.curv_win[self.step % self.curv_win_width] = grad_norm_squared - # remember iterations (steps) start from 0. - valid_window = self.curv_win[:min(self.curv_win_width, self.step + 1)] + self.curv_win[self.step % self.window_size] = g_norm_squared + valid_window = self.curv_win[:min(self.window_size, self.step + 1)] h_min_t = np.min(valid_window) h_max_t = np.max(valid_window) - #print(h_min_t, h_max_t) - self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t - self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t - self.h_min = self.h_min_lpf * self.debias - self.h_max = self.h_max_lpf * self.debias - # FIXME? the first few iterations are huuuuuuuge for regression. - #print(self.h_min, self.h_max) - # grad_variance() - self.grad_avg_lpf = b * self.grad_avg_lpf + m1b * dW - self.grad_avg_squared_lpf = b * self.grad_avg_squared_lpf + m1b * grad_squared - self.grad_avg = self.grad_avg_lpf * self.debias - self.grad_avg_squared = self.grad_avg_squared_lpf * self.debias - # FIXME: reimplement, this is weird. - #self._grad_avg = [self._moving_averager.average(dW)] - #self._grad_avg_squared = [np.square(val) for val in self._grad_avg] - #self._grad_var = self._grad_norm_squared_avg - np.add_n( [np.reduce_sum(val) for val in self._grad_avg_squared] ) - # || g^2_avg - g_avg^2 ||_1 - #self.grad_var = grad_norm_squared_avg - np.sum(self.grad_avg_squared) - # note: the abs probably isn't necessary here. - self.grad_var = np.sum(np.abs(self.grad_avg_squared - np.square(self.grad_avg))) - #print(self.grad_var) + self.g_lpf = b * self.g_lpf + m1b * g + #self.g_squared_lpf = b * self.g_squared_lpf + m1b * g_squared + self.g_norm_squared_lpf = b * self.g_norm_squared_lpf + m1b * g_norm_squared + self.g_norm_lpf = b * self.g_norm_lpf + m1b * g_norm + self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t + self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t - # dist_to_opt() - grad_norm = np.sqrt(grad_norm_squared) - self.grad_norm_avg_lpf = b * self.grad_norm_avg_lpf + m1b * grad_norm - grad_norm_avg = self.grad_norm_avg_lpf * self.debias - # single iteration distance estimation. - dist_to_opt = grad_norm_avg / grad_norm_squared_avg - # running average of distance - self.dist_to_opt_avg_lpf = b * self.dist_to_opt_avg_lpf + m1b * dist_to_opt - self.dist_to_opt_avg = self.dist_to_opt_avg_lpf * self.debias + g_avg = debias * self.g_lpf + #g_squared_avg = debias * self.g_squared_lpf + g_norm_squared_avg = debias * self.g_norm_squared_lpf + g_norm_avg = debias * self.g_norm_lpf + self.h_min = debias * self.h_min_lpf + self.h_max = debias * self.h_max_lpf + + dist = g_norm_avg / g_norm_squared_avg + + self.dist_lpf = b * self.dist_lpf + m1b * dist + + self.dist_avg = debias * self.dist_lpf + + self.g_var = g_norm_squared_avg - np.sum(np.square(g_avg)) + # alternatively: + #self.g_var = np.sum(np.abs(g_squared_avg - np.square(g_avg))) - # update_hyper_param() if self.step > 0: mu_for_real = self.get_mu() lr_for_real = self.get_lr(mu_for_real) self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real - self.alpha_lpf = b * self.alpha_lpf + m1b * lr_for_real - self.mu = self.mu_lpf * self.debias - self.alpha = self.alpha_lpf * self.debias - - ### + self.lr_lpf = b * self.lr_lpf + m1b * lr_for_real + self.mu = debias * self.mu_lpf + self.lr = debias * self.lr_lpf self.step += 1 self.beta_t *= self.beta @@ -854,6 +850,13 @@ def optim_from_config(config): b1 = np.exp(-1/d1) b2 = np.exp(-1/d2) optim = FTML(b1=b1, b2=b2) + elif config.optim == 'yf': + d1 = config.optim_decay1 if 'optim_decay1' in config else 999.5 + d2 = config.optim_decay2 if 'optim_decay2' in config else 999.5 + if d1 != d2: + raise Exception("yellowfin only uses one decay term.") + beta = np.exp(-1/d1) + optim = YellowFin(beta=beta) elif config.optim in ('rms', 'rmsprop'): d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5 mu = np.exp(-1/d2) diff --git a/onn_mnist.py b/onn_mnist.py index 2c3b246..eb47cdd 100755 --- a/onn_mnist.py +++ b/onn_mnist.py @@ -46,19 +46,19 @@ else: learner_class = None #SGDR restart_decay = 0.5 - n_dense = 2 + n_dense = 1 n_denses = 0 new_dims = (4, 12) - activation = Relu + activation = Relu # GeluApprox - reg = None # L1L2(3.2e-5, 3.2e-4) - final_reg = None # L1L2(3.2e-5, 1e-3) + reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4) + final_reg = None # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3) dropout = None # 0.05 actreg_lamb = None #1e-4 load_fn = None - save_fn = 'mnist.h5' - log_fn = 'mnist_losses.npz' + save_fn = 'mnist3.h5' + log_fn = 'mnist_losses3.npz' fn = 'mnist.npz' mnist_dim = 28 @@ -132,7 +132,8 @@ model = Model(x, y, unsafe=True) lr *= np.sqrt(bs) -optim = YellowFin() +optim = YellowFin() #beta=np.exp(-1/240) +#optim = MomentumClip(0.8, 0.8) if learner_class == SGDR: learner = learner_class(optim, epochs=epochs//starts, rate=lr, restarts=starts-1, restart_decay=restart_decay,