.
This commit is contained in:
parent
361aefcb29
commit
3b27a1a05e
1 changed files with 135 additions and 52 deletions
187
optim_nn.py
187
optim_nn.py
|
@ -1,13 +1,16 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
# ugly shorthand:
nf = np.float32
# np.int was only an alias for the builtin int; it was deprecated in
# NumPy 1.20 and removed in 1.24, so refer to the builtin directly.
ni = int


def nfa(x):
    """Return x converted to an array with dtype nf (float32)."""
    return np.array(x, dtype=nf)


def nia(x):
    """Return x converted to an array with dtype ni (builtin int)."""
    return np.array(x, dtype=ni)
|
||||
|
||||
# just for speed, not strictly essential:
|
||||
from scipy.special import expit as sigmoid
|
||||
|
||||
# used for numbering layers like Keras:
|
||||
from collections import defaultdict
|
||||
|
||||
# Initializations
|
||||
|
@ -47,7 +50,7 @@ class SquaredHalved(Loss):
|
|||
return r
|
||||
|
||||
class SomethingElse(Loss):
|
||||
# generalizes Absolute and SquaredHalved
|
||||
# generalizes Absolute and SquaredHalved (|dx| = 1)
|
||||
# plot: https://www.desmos.com/calculator/fagjg9vuz7
|
||||
def __init__(self, a=4/3):
|
||||
assert 1 <= a <= 2, "parameter out of range"
|
||||
|
@ -99,7 +102,7 @@ class Momentum(Optimizer):
|
|||
|
||||
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
|
||||
self.dWprev[:] = V
|
||||
if self.nesterov:
|
||||
if self.nesterov: # TODO: is this correct? looks weird
|
||||
return self.mu * V - self.alpha * (dW + W * self.lamb)
|
||||
else:
|
||||
return V
|
||||
|
@ -528,17 +531,14 @@ class Model:
|
|||
|
||||
f.close()
|
||||
|
||||
class Ritual:
|
||||
def __init__(self,
|
||||
optim=None,
|
||||
learn_rate=1e-3, learn_anneal=1, learn_advance=0,
|
||||
loss=None, mloss=None):
|
||||
self.optim = optim if optim is not None else SGD()
|
||||
class Ritual: # i'm just making up names at this point
|
||||
def __init__(self, learner=None, loss=None, mloss=None):
|
||||
self.learner = learner if learner is not None else Learner(Optimizer())
|
||||
self.loss = loss if loss is not None else Squared()
|
||||
self.mloss = mloss if mloss is not None else loss
|
||||
self.learn_rate = nf(learn_rate)
|
||||
self.learn_anneal = nf(learn_anneal)
|
||||
self.learn_advance = nf(learn_advance)
|
||||
|
||||
def reset(self):
|
||||
self.learner.reset(optim=True)
|
||||
|
||||
def measure(self, residual):
|
||||
return self.mloss.mean(residual)
|
||||
|
@ -546,17 +546,6 @@ class Ritual:
|
|||
def derive(self, residual):
|
||||
return self.loss.dmean(residual)
|
||||
|
||||
def update(self, dW, W):
|
||||
self.optim.update(dW, W)
|
||||
|
||||
def prepare(self, epoch):
|
||||
self.optim.alpha = self.learn_rate * self.learn_anneal**epoch
|
||||
|
||||
def restart(self, optim=False):
|
||||
self.learn_rate *= self.learn_anneal**self.learn_advance
|
||||
if optim:
|
||||
self.optim.reset()
|
||||
|
||||
def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
|
||||
cumsum_loss = 0
|
||||
batch_count = inputs.shape[0] // batch_size
|
||||
|
@ -570,7 +559,7 @@ class Ritual:
|
|||
residual = predicted - batch_outputs
|
||||
|
||||
model.backward(self.derive(residual))
|
||||
self.update(model.dW, model.W)
|
||||
self.learner.optim.update(model.dW, model.W)
|
||||
|
||||
batch_loss = self.measure(residual)
|
||||
if np.isnan(batch_loss):
|
||||
|
@ -584,6 +573,89 @@ class Ritual:
|
|||
else:
|
||||
return avg_loss
|
||||
|
||||
class Learner:
    """Tracks the current epoch and keeps the optimizer's rate in sync.

    The base class uses a constant rate (start_rate); subclasses override
    rate_at() to implement a schedule. Assigning to `epoch` recomputes the
    rate, and `rate` reads/writes `optim.alpha` directly.
    """

    def __init__(self, optim, epochs=100, rate=None):
        """
        optim  -- an Optimizer; its `alpha` attribute is used as the rate.
        epochs -- total number of epochs that next() will hand out.
        rate   -- starting rate; when None, optim.alpha is used instead.
        """
        assert isinstance(optim, Optimizer)
        self.optim = optim
        self.start_rate = optim.alpha if rate is None else float(rate)
        self.epochs = int(epochs)
        self.reset()

    def reset(self, optim=False):
        """Rewind to epoch 0; also reset the optimizer when optim is true."""
        self.started = False
        self.epoch = 0
        if optim:
            self.optim.reset()

    @property
    def epoch(self):
        return self._epoch

    @epoch.setter
    def epoch(self, new_epoch):
        # changing the epoch always refreshes the optimizer's rate.
        self._epoch = int(new_epoch)
        self.rate = self.rate_at(self._epoch)

    @property
    def rate(self):
        return self.optim.alpha

    @rate.setter
    def rate(self, new_rate):
        self.optim.alpha = new_rate

    def rate_at(self, epoch):
        """Return the rate to use at the given epoch (constant here)."""
        return self.start_rate

    def next(self):
        """Prepare the next epoch; return whether training should continue.

        The first call after a reset does not advance the epoch counter,
        it only marks the learner as started. Every later call advances
        by one epoch until `epochs` epochs have been handed out.
        """
        if not self.started:
            # the current epoch has not been trained yet, so it is valid
            # as long as it lies inside the schedule. (checking
            # epoch + 1 here would skip the only epoch when epochs == 1.)
            if self.epoch >= self.epochs:
                return False
            self.started = True
            self.epoch = self.epoch  # poke property setter just in case
            return True
        if self.epoch + 1 >= self.epochs:
            return False
        self.epoch += 1
        return True

    @property
    def final_rate(self):
        """The rate that will be used for the last scheduled epoch."""
        return self.rate_at(self.epochs - 1)
|
||||
|
||||
class AnnealingLearner(Learner):
    """Learner whose rate is multiplied by 0.5 every `halve_every` epochs."""

    def __init__(self, optim, epochs=100, rate=None, halve_every=10):
        # the decay factor has to exist before the base constructor runs,
        # since that constructor resets the epoch and evaluates rate_at().
        period = float(halve_every)
        self.halve_every = period
        self.anneal = 0.5 ** (1 / period)
        super().__init__(optim, epochs, rate)

    def rate_at(self, epoch):
        """Return the base rate scaled by anneal raised to the epoch."""
        base_rate = super().rate_at(epoch)
        decay = self.anneal ** epoch
        return base_rate * decay
|
||||
|
||||
class DumbLearner(AnnealingLearner):
    """Annealing schedule that is replayed `restarts` extra times.

    Each cycle lasts `epochs` epochs and anneals from its own starting
    rate; every restart lowers that starting rate as though
    `restart_advance` additional epochs of annealing had passed.
    """
    # this is my own awful contraption. it's not really "SGD with restarts".

    def __init__(self, optim, epochs=100, rate=None, halve_every=10,
                 restarts=0, restart_advance=20, callback=None):
        """
        epochs          -- epochs per cycle (the total is epochs * (restarts + 1)).
        restarts        -- number of additional cycles after the first.
        restart_advance -- epochs' worth of annealing applied per restart.
        callback        -- optional callable; invoked with the restart number
                           when next() enters the first epoch of a new cycle.
        """
        # these must be set before the base constructor runs, because it
        # resets the epoch and thereby calls rate_at().
        self.restart_epochs = int(epochs)
        self.restarts = int(restarts)
        self.restart_advance = float(restart_advance)
        self.restart_callback = callback
        epochs = self.restart_epochs * (self.restarts + 1)
        super().__init__(optim, epochs, rate, halve_every)

    def rate_at(self, epoch):
        """Return the rate for an absolute epoch number across all cycles."""
        if self.restart_epochs == 0:
            # zero-length cycles: there is nothing to cycle through, and
            # dividing by the cycle length would raise ZeroDivisionError
            # as soon as the learner is constructed. fall back to the
            # plain annealed rate.
            return super().rate_at(epoch)
        restart, sub_epoch = divmod(epoch, self.restart_epochs)
        return super().rate_at(sub_epoch) * (self.anneal**self.restart_advance)**restart

    def next(self):
        """Advance like Learner.next(), firing the callback on each restart."""
        if not super().next():
            return False
        # restart_epochs cannot be zero here: the base class only returns
        # True when the total epoch count is positive.
        restart, sub_epoch = divmod(self.epoch, self.restart_epochs)
        if restart > 0 and sub_epoch == 0:
            if self.restart_callback is not None:
                self.restart_callback(restart)
        return True
|
||||
|
||||
def multiresnet(x, width, depth, block=2, multi=1,
|
||||
activation=Relu, style='batchless',
|
||||
init=init_he_normal):
|
||||
|
@ -725,12 +797,16 @@ def run(program, args=[]):
|
|||
print(str(node)+sep+('\n'+str(node)+sep).join(children))
|
||||
log('parameters', model.param_count)
|
||||
|
||||
#
|
||||
|
||||
training = config.epochs > 0 and config.restarts >= 0
|
||||
|
||||
if config.fn_load is not None:
|
||||
log('loading weights', config.fn_load)
|
||||
model.load_weights(config.fn_load)
|
||||
|
||||
#
|
||||
|
||||
if config.optim == 'adam':
|
||||
assert not config.nesterov, "unimplemented"
|
||||
optim = Adam()
|
||||
|
@ -742,6 +818,22 @@ def run(program, args=[]):
|
|||
else:
|
||||
raise Exception('unknown optimizer', config.optim)
|
||||
|
||||
def rscb(restart):
|
||||
measure_error() # declared later...
|
||||
log("restarting", restart)
|
||||
if config.restart_optim:
|
||||
optim.reset()
|
||||
|
||||
#
|
||||
|
||||
learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
|
||||
halve_every=config.learn_halve_every,
|
||||
restarts=config.restarts, restart_advance=config.learn_restart_advance,
|
||||
callback=rscb)
|
||||
log("final learning rate", "{:10.8f}".format(learner.final_rate))
|
||||
|
||||
#
|
||||
|
||||
def lookup_loss(maybe_name):
|
||||
if isinstance(maybe_name, Loss):
|
||||
return maybe_name
|
||||
|
@ -754,14 +846,7 @@ def run(program, args=[]):
|
|||
loss = lookup_loss(config.loss)
|
||||
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
||||
|
||||
anneal = 0.5**(1/config.learn_halve_every)
|
||||
ritual = Ritual(optim=optim,
|
||||
learn_rate=config.learn, learn_anneal=anneal,
|
||||
learn_advance=config.learn_restart_advance,
|
||||
loss=loss, mloss=mloss)
|
||||
|
||||
learn_end = config.learn * (anneal**config.learn_restart_advance)**config.restarts * anneal**(config.epochs - 1)
|
||||
log("final learning rate", "{:10.8f}".format(learn_end))
|
||||
ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
|
||||
|
||||
# Training
|
||||
|
||||
|
@ -788,30 +873,26 @@ def run(program, args=[]):
|
|||
train_losses.append(train_err)
|
||||
valid_losses.append(valid_err)
|
||||
|
||||
for i in range(config.restarts + 1):
|
||||
measure_error()
|
||||
measure_error()
|
||||
|
||||
if i > 0:
|
||||
log("restarting", i)
|
||||
ritual.restart(optim=config.restart_optim)
|
||||
assert inputs.shape[0] % config.batch_size == 0, \
|
||||
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
||||
while learner.next():
|
||||
indices = np.arange(inputs.shape[0])
|
||||
np.random.shuffle(indices)
|
||||
shuffled_inputs = inputs[indices] / x_scale
|
||||
shuffled_outputs = outputs[indices] / y_scale
|
||||
|
||||
assert inputs.shape[0] % config.batch_size == 0, \
|
||||
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
||||
for e in range(config.epochs):
|
||||
indices = np.arange(inputs.shape[0])
|
||||
np.random.shuffle(indices)
|
||||
shuffled_inputs = inputs[indices] / x_scale
|
||||
shuffled_outputs = outputs[indices] / y_scale
|
||||
avg_loss, losses = ritual.train_batched(model,
|
||||
shuffled_inputs, shuffled_outputs,
|
||||
config.batch_size,
|
||||
return_losses=True)
|
||||
batch_losses += losses
|
||||
|
||||
ritual.prepare(e)
|
||||
#log("learning rate", "{:10.8f}".format(ritual.optim.alpha))
|
||||
|
||||
avg_loss, losses = ritual.train_batched(model,
|
||||
shuffled_inputs, shuffled_outputs,
|
||||
config.batch_size,
|
||||
return_losses=True)
|
||||
log("average loss", "{:11.7f}".format(avg_loss))
|
||||
batch_losses += losses
|
||||
#log("learning rate", "{:10.8f}".format(learner.rate))
|
||||
#log("average loss", "{:11.7f}".format(avg_loss))
|
||||
fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:11.7f}"
|
||||
log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss))
|
||||
|
||||
measure_error()
|
||||
|
||||
|
@ -821,6 +902,8 @@ def run(program, args=[]):
|
|||
|
||||
# Evaluation
|
||||
|
||||
# this is just an example/test of how to predict a single output;
|
||||
# it doesn't measure the quality of the network or anything.
|
||||
a = (192, 128, 64)
|
||||
b = (64, 128, 192)
|
||||
X = np.expand_dims(np.hstack((a, b)), 0) / x_scale
|
||||
|
|
Loading…
Reference in a new issue