.
This commit is contained in:
parent
65fe5cad85
commit
205d64a8a0
|
@ -874,3 +874,39 @@ class SGDR(Learner):
|
||||||
if self.restart_callback is not None:
|
if self.restart_callback is not None:
|
||||||
self.restart_callback(restart)
|
self.restart_callback(restart)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
class TriangularCLR(Learner):
    """Cyclical learning-rate schedule shaped like a triangular wave.

    Within every cycle of ``frequency`` epochs the rate begins at
    ``lower_rate``, rises linearly to the upper rate at the half-way
    point, and falls linearly back to ``lower_rate``.

    The upper rate is handed to the base class as its third positional
    argument and is read back here as ``self.start_rate``.

    Disclaimer carried over from the original author: this was written
    without consulting the CLR paper(s); the triangular case is assumed
    to be simple enough to get right regardless.
    """

    # NOTE(review): presumably tells the training loop to query the rate
    # once per batch rather than once per epoch -- confirm against Learner.
    per_batch = True

    def __init__(self, optim, epochs=400, upper_rate=None, lower_rate=0,
                 frequency=100, callback=None):
        """Set up the schedule.

        optim      -- optimizer, forwarded untouched to the base class.
        epochs     -- total epoch count, forwarded to the base class.
        upper_rate -- peak rate; the base class keeps it as ``start_rate``.
        lower_rate -- trough rate, converted with ``_f``.
        frequency  -- cycle length in epochs; truncated to int, must be > 0.
        callback   -- optional callable invoked with the cycle index each
                      time a cycle completes (see ``next``).
        """
        self.frequency = int(frequency)
        assert self.frequency > 0
        self.callback = callback
        self.lower_rate = _f(lower_rate)
        # The schedule attributes are assigned before the base initializer
        # runs, exactly as in the original ordering.
        super().__init__(optim, epochs, upper_rate)

    def _t(self, epoch):
        """Return the wave position in [0, 1] for a (possibly fractional) epoch.

        0 at the start/end of a cycle, 1 at its midpoint.
        """
        half_period = self.frequency / 2
        # Shift by half a period so the folded distance is 0 at cycle
        # boundaries and maximal (== half_period) at the midpoint.
        phase = (epoch + half_period) % self.frequency
        return np.abs(phase - half_period) / half_period

    def rate_at(self, epoch):
        """Interpolate between ``lower_rate`` and the upper rate at ``epoch``."""
        # start_rate doubles as the upper bound of the cycle.
        span = self.start_rate - self.lower_rate
        return self._t(epoch) * span + self.lower_rate

    def next(self):
        """Advance the schedule by one step.

        Returns False as soon as the base class reports it cannot advance;
        otherwise returns True, after notifying ``callback`` (if any) with
        the number of completed cycles whenever ``self.epoch`` is greater
        than 1 and a whole multiple of ``frequency``.
        """
        advanced = super().next()
        if not advanced:
            return False

        cycle_finished = self.epoch > 1 and self.epoch % self.frequency == 0
        if cycle_finished and self.callback is not None:
            self.callback(self.epoch // self.frequency)
        return True
|
||||||
|
|
||||||
|
class SineCLR(TriangularCLR):
    """Variant of TriangularCLR that reshapes the wave with a sine.

    The parent's wave position is scaled by ``_pi * _inv2`` and passed
    through ``np.sin``; everything else is inherited unchanged.
    """

    def _t(self, epoch):
        """Return the sine of the parent's wave position scaled by ``_pi * _inv2``."""
        # NOTE(review): _pi and _inv2 are presumably pi and 1/2, which would
        # map the triangular value 0..1 onto a quarter sine arc 0..1 --
        # confirm where those constants are defined.
        triangular = super()._t(epoch)
        return np.sin(_pi * _inv2 * triangular)
|
||||||
|
|
|
@ -5,15 +5,15 @@ from optim_nn_core import _f
|
||||||
|
|
||||||
#np.random.seed(42069)
|
#np.random.seed(42069)
|
||||||
|
|
||||||
# train loss: 4.194040e-02
|
# train loss: 7.048363e-03
|
||||||
# train accuracy: 99.46%
|
# train accuracy: 99.96%
|
||||||
# valid loss: 1.998158e-01
|
# valid loss: 3.062232e-01
|
||||||
# valid accuracy: 97.26%
|
# valid accuracy: 97.22%
|
||||||
# TODO: add dropout or something to lessen overfitting
|
# TODO: add dropout or something to lessen overfitting
|
||||||
|
|
||||||
lr = 0.01
|
lr = 0.0032
|
||||||
epochs = 24
|
epochs = 125
|
||||||
starts = 2
|
starts = 5
|
||||||
restart_decay = 0.5
|
restart_decay = 0.5
|
||||||
bs = 100
|
bs = 100
|
||||||
|
|
||||||
|
@ -64,9 +64,15 @@ y = y.feed(Softmax())
|
||||||
model = Model(x, y, unsafe=True)
|
model = Model(x, y, unsafe=True)
|
||||||
|
|
||||||
optim = Adam()
|
optim = Adam()
|
||||||
learner = SGDR(optim, epochs=epochs//starts, rate=lr,
|
if 0:
|
||||||
restarts=starts - 1, restart_decay=restart_decay,
|
learner = SGDR(optim, epochs=epochs//starts, rate=lr,
|
||||||
expando=lambda i:0)
|
restarts=starts-1, restart_decay=restart_decay,
|
||||||
|
expando=lambda i:0)
|
||||||
|
else:
|
||||||
|
# learner = TriangularCLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||||
|
# frequency=epochs//starts)
|
||||||
|
learner = SineCLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||||
|
frequency=epochs//starts)
|
||||||
|
|
||||||
loss = CategoricalCrossentropy()
|
loss = CategoricalCrossentropy()
|
||||||
mloss = Accuracy()
|
mloss = Accuracy()
|
||||||
|
@ -89,9 +95,10 @@ def measure_error(quiet=False):
|
||||||
log(name + " accuracy", "{:6.2f}%".format(mloss * 100))
|
log(name + " accuracy", "{:6.2f}%".format(mloss * 100))
|
||||||
return loss, mloss
|
return loss, mloss
|
||||||
|
|
||||||
loss, mloss = print_error("train", inputs, outputs)
|
if not quiet:
|
||||||
train_losses.append(loss)
|
loss, mloss = print_error("train", inputs, outputs)
|
||||||
train_mlosses.append(mloss)
|
train_losses.append(loss)
|
||||||
|
train_mlosses.append(mloss)
|
||||||
loss, mloss = print_error("valid", valid_inputs, valid_outputs)
|
loss, mloss = print_error("valid", valid_inputs, valid_outputs)
|
||||||
valid_losses.append(loss)
|
valid_losses.append(loss)
|
||||||
valid_mlosses.append(mloss)
|
valid_mlosses.append(mloss)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user