diff --git a/onn.py b/onn.py index 37aa304..90a17e3 100755 --- a/onn.py +++ b/onn.py @@ -1034,6 +1034,10 @@ def optim_from_config(config): d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5 mu = np.exp(-1/d2) optim = RMSprop(mu=mu) + elif config.optim == 'rmsc': + d2 = config.optim_decay2 if 'optim_decay2' in config else 9.5 + mu = np.exp(-1/d2) + optim = RMSpropCentered(momentum=mu) elif config.optim == 'sgd': d1 = config.optim_decay1 if 'optim_decay1' in config else 0 clip = config.gradient_clip if 'gradient_clip' in config else 0.0