.
This commit is contained in:
parent
361aefcb29
commit
3b27a1a05e
1 changed files with 135 additions and 52 deletions
165
optim_nn.py
165
optim_nn.py
|
@ -1,13 +1,16 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
# ugly shorthand:
|
||||||
nf = np.float32
|
nf = np.float32
|
||||||
nfa = lambda x: np.array(x, dtype=nf)
|
nfa = lambda x: np.array(x, dtype=nf)
|
||||||
ni = np.int
|
ni = np.int
|
||||||
nia = lambda x: np.array(x, dtype=ni)
|
nia = lambda x: np.array(x, dtype=ni)
|
||||||
|
|
||||||
|
# just for speed, not strictly essential:
|
||||||
from scipy.special import expit as sigmoid
|
from scipy.special import expit as sigmoid
|
||||||
|
|
||||||
|
# used for numbering layers like Keras:
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
# Initializations
|
# Initializations
|
||||||
|
@ -47,7 +50,7 @@ class SquaredHalved(Loss):
|
||||||
return r
|
return r
|
||||||
|
|
||||||
class SomethingElse(Loss):
|
class SomethingElse(Loss):
|
||||||
# generalizes Absolute and SquaredHalved
|
# generalizes Absolute and SquaredHalved (|dx| = 1)
|
||||||
# plot: https://www.desmos.com/calculator/fagjg9vuz7
|
# plot: https://www.desmos.com/calculator/fagjg9vuz7
|
||||||
def __init__(self, a=4/3):
|
def __init__(self, a=4/3):
|
||||||
assert 1 <= a <= 2, "parameter out of range"
|
assert 1 <= a <= 2, "parameter out of range"
|
||||||
|
@ -99,7 +102,7 @@ class Momentum(Optimizer):
|
||||||
|
|
||||||
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
|
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
|
||||||
self.dWprev[:] = V
|
self.dWprev[:] = V
|
||||||
if self.nesterov:
|
if self.nesterov: # TODO: is this correct? looks weird
|
||||||
return self.mu * V - self.alpha * (dW + W * self.lamb)
|
return self.mu * V - self.alpha * (dW + W * self.lamb)
|
||||||
else:
|
else:
|
||||||
return V
|
return V
|
||||||
|
@ -528,17 +531,14 @@ class Model:
|
||||||
|
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
class Ritual:
|
class Ritual: # i'm just making up names at this point
|
||||||
def __init__(self,
|
def __init__(self, learner=None, loss=None, mloss=None):
|
||||||
optim=None,
|
self.learner = learner if learner is not None else Learner(Optimizer())
|
||||||
learn_rate=1e-3, learn_anneal=1, learn_advance=0,
|
|
||||||
loss=None, mloss=None):
|
|
||||||
self.optim = optim if optim is not None else SGD()
|
|
||||||
self.loss = loss if loss is not None else Squared()
|
self.loss = loss if loss is not None else Squared()
|
||||||
self.mloss = mloss if mloss is not None else loss
|
self.mloss = mloss if mloss is not None else loss
|
||||||
self.learn_rate = nf(learn_rate)
|
|
||||||
self.learn_anneal = nf(learn_anneal)
|
def reset(self):
|
||||||
self.learn_advance = nf(learn_advance)
|
self.learner.reset(optim=True)
|
||||||
|
|
||||||
def measure(self, residual):
|
def measure(self, residual):
|
||||||
return self.mloss.mean(residual)
|
return self.mloss.mean(residual)
|
||||||
|
@ -546,17 +546,6 @@ class Ritual:
|
||||||
def derive(self, residual):
|
def derive(self, residual):
|
||||||
return self.loss.dmean(residual)
|
return self.loss.dmean(residual)
|
||||||
|
|
||||||
def update(self, dW, W):
|
|
||||||
self.optim.update(dW, W)
|
|
||||||
|
|
||||||
def prepare(self, epoch):
|
|
||||||
self.optim.alpha = self.learn_rate * self.learn_anneal**epoch
|
|
||||||
|
|
||||||
def restart(self, optim=False):
|
|
||||||
self.learn_rate *= self.learn_anneal**self.learn_advance
|
|
||||||
if optim:
|
|
||||||
self.optim.reset()
|
|
||||||
|
|
||||||
def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
|
def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
|
||||||
cumsum_loss = 0
|
cumsum_loss = 0
|
||||||
batch_count = inputs.shape[0] // batch_size
|
batch_count = inputs.shape[0] // batch_size
|
||||||
|
@ -570,7 +559,7 @@ class Ritual:
|
||||||
residual = predicted - batch_outputs
|
residual = predicted - batch_outputs
|
||||||
|
|
||||||
model.backward(self.derive(residual))
|
model.backward(self.derive(residual))
|
||||||
self.update(model.dW, model.W)
|
self.learner.optim.update(model.dW, model.W)
|
||||||
|
|
||||||
batch_loss = self.measure(residual)
|
batch_loss = self.measure(residual)
|
||||||
if np.isnan(batch_loss):
|
if np.isnan(batch_loss):
|
||||||
|
@ -584,6 +573,89 @@ class Ritual:
|
||||||
else:
|
else:
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
class Learner:
|
||||||
|
def __init__(self, optim, epochs=100, rate=None):
|
||||||
|
assert isinstance(optim, Optimizer)
|
||||||
|
self.optim = optim
|
||||||
|
self.start_rate = optim.alpha if rate is None else float(rate)
|
||||||
|
self.epochs = int(epochs)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self, optim=False):
|
||||||
|
self.started = False
|
||||||
|
self.epoch = 0
|
||||||
|
if optim:
|
||||||
|
self.optim.reset()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def epoch(self):
|
||||||
|
return self._epoch
|
||||||
|
|
||||||
|
@epoch.setter
|
||||||
|
def epoch(self, new_epoch):
|
||||||
|
self._epoch = int(new_epoch)
|
||||||
|
self.rate = self.rate_at(self._epoch)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rate(self):
|
||||||
|
return self.optim.alpha
|
||||||
|
|
||||||
|
@rate.setter
|
||||||
|
def rate(self, new_rate):
|
||||||
|
self.optim.alpha = new_rate
|
||||||
|
|
||||||
|
def rate_at(self, epoch):
|
||||||
|
return self.start_rate
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
# returns whether or not to continue training. updates.
|
||||||
|
if self.epoch + 1 >= self.epochs:
|
||||||
|
return False
|
||||||
|
if self.started:
|
||||||
|
self.epoch += 1
|
||||||
|
else:
|
||||||
|
self.started = True
|
||||||
|
self.epoch = self.epoch # poke property setter just in case
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def final_rate(self):
|
||||||
|
return self.rate_at(self.epochs - 1)
|
||||||
|
|
||||||
|
class AnnealingLearner(Learner):
|
||||||
|
def __init__(self, optim, epochs=100, rate=None, halve_every=10):
|
||||||
|
self.halve_every = float(halve_every)
|
||||||
|
self.anneal = 0.5**(1/self.halve_every)
|
||||||
|
super().__init__(optim, epochs, rate)
|
||||||
|
|
||||||
|
def rate_at(self, epoch):
|
||||||
|
return super().rate_at(epoch) * self.anneal**epoch
|
||||||
|
|
||||||
|
class DumbLearner(AnnealingLearner):
|
||||||
|
# this is my own awful contraption. it's not really "SGD with restarts".
|
||||||
|
def __init__(self, optim, epochs=100, rate=None, halve_every=10, restarts=0, restart_advance=20, callback=None):
|
||||||
|
self.restart_epochs = int(epochs)
|
||||||
|
self.restarts = int(restarts)
|
||||||
|
self.restart_advance = float(restart_advance)
|
||||||
|
self.restart_callback = callback
|
||||||
|
epochs = self.restart_epochs * (self.restarts + 1)
|
||||||
|
super().__init__(optim, epochs, rate, halve_every)
|
||||||
|
|
||||||
|
def rate_at(self, epoch):
|
||||||
|
sub_epoch = epoch % self.restart_epochs
|
||||||
|
restart = epoch // self.restart_epochs
|
||||||
|
return super().rate_at(sub_epoch) * (self.anneal**self.restart_advance)**restart
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
if not super().next():
|
||||||
|
return False
|
||||||
|
sub_epoch = self.epoch % self.restart_epochs
|
||||||
|
restart = self.epoch // self.restart_epochs
|
||||||
|
if restart > 0 and sub_epoch == 0:
|
||||||
|
if self.restart_callback is not None:
|
||||||
|
self.restart_callback(restart)
|
||||||
|
return True
|
||||||
|
|
||||||
def multiresnet(x, width, depth, block=2, multi=1,
|
def multiresnet(x, width, depth, block=2, multi=1,
|
||||||
activation=Relu, style='batchless',
|
activation=Relu, style='batchless',
|
||||||
init=init_he_normal):
|
init=init_he_normal):
|
||||||
|
@ -725,12 +797,16 @@ def run(program, args=[]):
|
||||||
print(str(node)+sep+('\n'+str(node)+sep).join(children))
|
print(str(node)+sep+('\n'+str(node)+sep).join(children))
|
||||||
log('parameters', model.param_count)
|
log('parameters', model.param_count)
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
training = config.epochs > 0 and config.restarts >= 0
|
training = config.epochs > 0 and config.restarts >= 0
|
||||||
|
|
||||||
if config.fn_load is not None:
|
if config.fn_load is not None:
|
||||||
log('loading weights', config.fn_load)
|
log('loading weights', config.fn_load)
|
||||||
model.load_weights(config.fn_load)
|
model.load_weights(config.fn_load)
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
if config.optim == 'adam':
|
if config.optim == 'adam':
|
||||||
assert not config.nesterov, "unimplemented"
|
assert not config.nesterov, "unimplemented"
|
||||||
optim = Adam()
|
optim = Adam()
|
||||||
|
@ -742,6 +818,22 @@ def run(program, args=[]):
|
||||||
else:
|
else:
|
||||||
raise Exception('unknown optimizer', config.optim)
|
raise Exception('unknown optimizer', config.optim)
|
||||||
|
|
||||||
|
def rscb(restart):
|
||||||
|
measure_error() # declared later...
|
||||||
|
log("restarting", restart)
|
||||||
|
if config.restart_optim:
|
||||||
|
optim.reset()
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
|
learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
|
||||||
|
halve_every=config.learn_halve_every,
|
||||||
|
restarts=config.restarts, restart_advance=config.learn_restart_advance,
|
||||||
|
callback=rscb)
|
||||||
|
log("final learning rate", "{:10.8f}".format(learner.final_rate))
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
def lookup_loss(maybe_name):
|
def lookup_loss(maybe_name):
|
||||||
if isinstance(maybe_name, Loss):
|
if isinstance(maybe_name, Loss):
|
||||||
return maybe_name
|
return maybe_name
|
||||||
|
@ -754,14 +846,7 @@ def run(program, args=[]):
|
||||||
loss = lookup_loss(config.loss)
|
loss = lookup_loss(config.loss)
|
||||||
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
||||||
|
|
||||||
anneal = 0.5**(1/config.learn_halve_every)
|
ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
|
||||||
ritual = Ritual(optim=optim,
|
|
||||||
learn_rate=config.learn, learn_anneal=anneal,
|
|
||||||
learn_advance=config.learn_restart_advance,
|
|
||||||
loss=loss, mloss=mloss)
|
|
||||||
|
|
||||||
learn_end = config.learn * (anneal**config.learn_restart_advance)**config.restarts * anneal**(config.epochs - 1)
|
|
||||||
log("final learning rate", "{:10.8f}".format(learn_end))
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
||||||
|
@ -788,31 +873,27 @@ def run(program, args=[]):
|
||||||
train_losses.append(train_err)
|
train_losses.append(train_err)
|
||||||
valid_losses.append(valid_err)
|
valid_losses.append(valid_err)
|
||||||
|
|
||||||
for i in range(config.restarts + 1):
|
|
||||||
measure_error()
|
measure_error()
|
||||||
|
|
||||||
if i > 0:
|
|
||||||
log("restarting", i)
|
|
||||||
ritual.restart(optim=config.restart_optim)
|
|
||||||
|
|
||||||
assert inputs.shape[0] % config.batch_size == 0, \
|
assert inputs.shape[0] % config.batch_size == 0, \
|
||||||
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
||||||
for e in range(config.epochs):
|
while learner.next():
|
||||||
indices = np.arange(inputs.shape[0])
|
indices = np.arange(inputs.shape[0])
|
||||||
np.random.shuffle(indices)
|
np.random.shuffle(indices)
|
||||||
shuffled_inputs = inputs[indices] / x_scale
|
shuffled_inputs = inputs[indices] / x_scale
|
||||||
shuffled_outputs = outputs[indices] / y_scale
|
shuffled_outputs = outputs[indices] / y_scale
|
||||||
|
|
||||||
ritual.prepare(e)
|
|
||||||
#log("learning rate", "{:10.8f}".format(ritual.optim.alpha))
|
|
||||||
|
|
||||||
avg_loss, losses = ritual.train_batched(model,
|
avg_loss, losses = ritual.train_batched(model,
|
||||||
shuffled_inputs, shuffled_outputs,
|
shuffled_inputs, shuffled_outputs,
|
||||||
config.batch_size,
|
config.batch_size,
|
||||||
return_losses=True)
|
return_losses=True)
|
||||||
log("average loss", "{:11.7f}".format(avg_loss))
|
|
||||||
batch_losses += losses
|
batch_losses += losses
|
||||||
|
|
||||||
|
#log("learning rate", "{:10.8f}".format(learner.rate))
|
||||||
|
#log("average loss", "{:11.7f}".format(avg_loss))
|
||||||
|
fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:11.7f}"
|
||||||
|
log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss))
|
||||||
|
|
||||||
measure_error()
|
measure_error()
|
||||||
|
|
||||||
if config.fn_save is not None:
|
if config.fn_save is not None:
|
||||||
|
@ -821,6 +902,8 @@ def run(program, args=[]):
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
|
|
||||||
|
# this is just an example/test of how to predict a single output;
|
||||||
|
# it doesn't measure the quality of the network or anything.
|
||||||
a = (192, 128, 64)
|
a = (192, 128, 64)
|
||||||
b = (64, 128, 192)
|
b = (64, 128, 192)
|
||||||
X = np.expand_dims(np.hstack((a, b)), 0) / x_scale
|
X = np.expand_dims(np.hstack((a, b)), 0) / x_scale
|
||||||
|
|
Loading…
Reference in a new issue