allow MomentumClip, SineCLR, WaveCLR in config
This commit is contained in:
parent
e7a6974829
commit
4a108a10ae
1 changed files with 21 additions and 3 deletions
24
onn.py
24
onn.py
|
@ -886,9 +886,13 @@ def optim_from_config(config):
|
|||
optim = RMSprop(mu=mu)
|
||||
elif config.optim == 'sgd':
|
||||
d1 = config.optim_decay1 if 'optim_decay1' in config else 0
|
||||
if d1 > 0:
|
||||
b1 = np.exp(-1/d1)
|
||||
optim = Momentum(mu=b1, nesterov=config.nesterov)
|
||||
clip = config.gradient_clip if 'gradient_clip' in config else 0.0
|
||||
if d1 > 0 or clip > 0:
|
||||
b1 = np.exp(-1/d1) if d1 > 0 else 0
|
||||
if clip > 0:
|
||||
optim = MomentumClip(mu=b1, nesterov=config.nesterov, clip=clip)
|
||||
else:
|
||||
optim = Momentum(mu=b1, nesterov=config.nesterov)
|
||||
else:
|
||||
optim = Optimizer()
|
||||
else:
|
||||
|
@ -904,6 +908,20 @@ def learner_from_config(config, optim, rscb):
|
|||
callback=rscb, expando=expando)
|
||||
# final learning rate isn't of interest here; it's gonna be close to 0.
|
||||
log('total epochs', learner.epochs)
|
||||
elif config.learner in ('sin', 'sine'):
|
||||
lower_rate = config.learn * 1e-5 # TODO: allow access to this.
|
||||
epochs = config.epochs * (config.restarts + 1)
|
||||
frequency = config.epochs
|
||||
learner = SineCLR(optim, epochs=epochs, frequency=frequency,
|
||||
upper_rate=config.learn, lower_rate=lower_rate,
|
||||
callback=rscb)
|
||||
elif config.learner == 'wave':
|
||||
lower_rate = config.learn * 1e-5 # TODO: allow access to this.
|
||||
epochs = config.epochs * (config.restarts + 1)
|
||||
frequency = config.epochs
|
||||
learner = WaveCLR(optim, epochs=epochs, frequency=frequency,
|
||||
upper_rate=config.learn, lower_rate=lower_rate,
|
||||
callback=rscb)
|
||||
elif config.learner == 'anneal':
|
||||
learner = AnnealingLearner(optim, epochs=config.epochs, rate=config.learn,
|
||||
halve_every=config.learn_halve_every)
|
||||
|
|
Loading…
Reference in a new issue