diff --git a/onn/optimizer.py b/onn/optimizer.py
index 14d0d1e..2aa05f8 100644
--- a/onn/optimizer.py
+++ b/onn/optimizer.py
@@ -248,141 +248,6 @@ class MomentumClip(Optimizer):
         return -self.lr * self.accum
 
 
-class YellowFin(Optimizer):
-    # paper: https://arxiv.org/abs/1706.03471
-    # knowyourmeme: http://cs.stanford.edu/~zjian/project/YellowFin/
-    # author's implementation:
-    # https://github.com/JianGoForIt/YellowFin/blob/master/tuner_utils/yellowfin.py
-    # code lifted:
-    # https://gist.github.com/botev/f8b32c00eafee222e47393f7f0747666
-
-    def __init__(self, lr=0.1, mu=0.0, beta=0.999, window_size=20,
-                 debias=True, clip=1.0):
-        self.lr_default = _f(lr)
-        self.mu_default = _f(mu)
-        self.beta = _f(beta)
-        self.window_size = int(window_size)  # curv_win_width
-        self.debias_enabled = bool(debias)
-        self.clip = _f(clip)
-
-        self.mu = _f(mu)  # momentum
-        super().__init__(lr)
-
-    def reset(self):
-        self.accum = None
-
-        self.lr = self.lr_default
-        self.mu = self.mu_default
-
-        self.step = 0
-        self.beta_t = self.beta
-
-        self.curv_win = np.zeros([self.window_size, ], dtype=np.float32)
-
-        self.h_min = None
-        self.h_max = None
-
-        self.g_lpf = 0
-        # self.g_squared_lpf = 0
-        self.g_norm_squared_lpf = 0
-        self.g_norm_lpf = 0
-        self.h_min_lpf = 0
-        self.h_max_lpf = 0
-        self.dist_lpf = 0
-        self.lr_lpf = 0
-        self.mu_lpf = 0
-
-    def get_lr_mu(self):
-        p = (np.square(self.dist_avg) * np.square(self.h_min)) \
-            / (2 * self.g_var)
-        w3 = p * (np.sqrt(0.25 + p / 27.0) - 0.5)
-        w = np.power(w3, 1/3)
-        y = w - p / (3 * w)
-        sqrt_mu1 = y + 1
-
-        sqrt_h_min = np.sqrt(self.h_min)
-        sqrt_h_max = np.sqrt(self.h_max)
-        sqrt_mu2 = (sqrt_h_max - sqrt_h_min) / (sqrt_h_max + sqrt_h_min)
-
-        sqrt_mu = max(sqrt_mu1, sqrt_mu2)
-        if sqrt_mu2 > sqrt_mu1:
-            print('note: taking dr calculation. something may have exploded.')
-
-        lr = np.square(1 - sqrt_mu) / self.h_min
-        mu = np.square(sqrt_mu)
-        return lr, mu
-
-    def compute(self, dW, W):
-        if self.accum is None:
-            self.accum = np.zeros_like(dW)
-
-        # TODO: prevent allocations everywhere by using [:].
-        #       assuming that really works. i haven't actually checked.
-
-        total_norm = np.linalg.norm(dW)
-        clip_scale = self.clip / (total_norm + 1e-6)
-        if clip_scale < 1:
-            # print("clipping gradients; norm: {:10.5f}".format(total_norm))
-            dW *= clip_scale
-
-        # fmt = 'W std: {:10.7f}e-3, dWstd: {:10.7f}e-3, V std: {:10.7f}e-3'
-        # print(fmt.format(np.std(W), np.std(dW) * 100, np.std(V) * 100))
-
-        b = self.beta
-        m1b = 1 - self.beta
-        debias = 1 / (1 - self.beta_t) if self.debias_enabled else 1
-
-        g = dW
-        g_squared = np.square(g)
-        g_norm_squared = np.sum(g_squared)
-        g_norm = np.sqrt(g_norm_squared)
-
-        self.curv_win[self.step % self.window_size] = g_norm_squared
-        valid_window = self.curv_win[:min(self.window_size, self.step + 1)]
-        h_min_t = np.min(valid_window)
-        h_max_t = np.max(valid_window)
-
-        self.g_lpf = b * self.g_lpf + m1b * g
-        # self.g_squared_lpf = b * self.g_squared_lpf + m1b * g_squared
-        self.g_norm_squared_lpf = b * self.g_norm_squared_lpf \
-            + m1b * g_norm_squared
-        self.g_norm_lpf = b * self.g_norm_lpf + m1b * g_norm
-        self.h_min_lpf = b * self.h_min_lpf + m1b * h_min_t
-        self.h_max_lpf = b * self.h_max_lpf + m1b * h_max_t
-
-        g_avg = debias * self.g_lpf
-        # g_squared_avg = debias * self.g_squared_lpf
-        g_norm_squared_avg = debias * self.g_norm_squared_lpf
-        g_norm_avg = debias * self.g_norm_lpf
-        self.h_min = debias * self.h_min_lpf
-        self.h_max = debias * self.h_max_lpf
-        assert self.h_max >= self.h_min
-
-        dist = g_norm_avg / g_norm_squared_avg
-
-        self.dist_lpf = b * self.dist_lpf + m1b * dist
-
-        self.dist_avg = debias * self.dist_lpf
-
-        self.g_var = g_norm_squared_avg - np.sum(np.square(g_avg))
-        # equivalently:
-        # self.g_var = np.sum(np.abs(g_squared_avg - np.square(g_avg)))
-
-        if self.step > 0:
-            lr_for_real, mu_for_real = self.get_lr_mu()
-            self.mu_lpf = b * self.mu_lpf + m1b * mu_for_real
-            self.lr_lpf = b * self.lr_lpf + m1b * lr_for_real
-            self.mu = debias * self.mu_lpf
-            self.lr = debias * self.lr_lpf
-
-        self.accum[:] = self.accum * self.mu - self.lr * dW
-        V = self.accum
-
-        self.step += 1
-        self.beta_t *= self.beta
-        return V
-
-
 class AddSign(Optimizer):
     # paper: https://arxiv.org/abs/1709.07417