diff --git a/onn.py b/onn.py index fa718c1..eeb209c 100755 --- a/onn.py +++ b/onn.py @@ -886,9 +886,13 @@ def optim_from_config(config): optim = RMSprop(mu=mu) elif config.optim == 'sgd': d1 = config.optim_decay1 if 'optim_decay1' in config else 0 - if d1 > 0: - b1 = np.exp(-1/d1) - optim = Momentum(mu=b1, nesterov=config.nesterov) + clip = config.gradient_clip if 'gradient_clip' in config else 0.0 + if d1 > 0 or clip > 0: + b1 = np.exp(-1/d1) if d1 > 0 else 0 + if clip > 0: + optim = MomentumClip(mu=b1, nesterov=config.nesterov, clip=clip) + else: + optim = Momentum(mu=b1, nesterov=config.nesterov) else: optim = Optimizer() else: @@ -904,6 +908,20 @@ def learner_from_config(config, optim, rscb): callback=rscb, expando=expando) # final learning rate isn't of interest here; it's gonna be close to 0. log('total epochs', learner.epochs) + elif config.learner in ('sin', 'sine'): + lower_rate = config.learn * 1e-5 # TODO: allow access to this. + epochs = config.epochs * (config.restarts + 1) + frequency = config.epochs + learner = SineCLR(optim, epochs=epochs, frequency=frequency, + upper_rate=config.learn, lower_rate=lower_rate, + callback=rscb) + elif config.learner == 'wave': + lower_rate = config.learn * 1e-5 # TODO: allow access to this. + epochs = config.epochs * (config.restarts + 1) + frequency = config.epochs + learner = WaveCLR(optim, epochs=epochs, frequency=frequency, + upper_rate=config.learn, lower_rate=lower_rate, + callback=rscb) elif config.learner == 'anneal': learner = AnnealingLearner(optim, epochs=config.epochs, rate=config.learn, halve_every=config.learn_halve_every)