Connor Olding 2017-01-12 14:45:07 -08:00
parent 02e03ad85e
commit f12e408c7e


@@ -98,7 +98,7 @@ class Momentum(Optimizer):
             self.dWprev = np.copy(dW)
 
         V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
-        self.dWprev = V
+        self.dWprev[:] = V
         if self.nesterov:
             return self.mu * V - self.alpha * (dW + W * self.lamb)
         else:
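
Note: the one-line change above replaces a name rebinding with an in-place copy, so `self.dWprev` keeps its own buffer instead of becoming an alias of `V`. A minimal illustration of the difference, using toy names that are not part of the diff:

    import numpy as np

    v = np.array([1.0, 2.0, 3.0])   # stands in for V
    a = np.zeros(3)                 # stands in for self.dWprev

    a = v             # rebinding: `a` and `v` now refer to the same array object
    assert a is v

    b = np.zeros(3)
    b[:] = v          # in-place copy: `b` keeps its own storage, only its values change
    assert b is not v and (b == v).all()
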
@@ -412,7 +412,7 @@ class Model:
             offset += node.size
 
     def traverse(self, nodes, node):
-        if node == x:
+        if node == self.x:
             return [node]
         for parent in node.parents:
             if parent not in nodes:
@@ -465,6 +465,60 @@ class Model:
     def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)
 
+class Ritual:
+    def __init__(self,
+                 optim=None,
+                 learn_rate=1e-3, learn_anneal=1, learn_advance=0,
+                 loss=None, mloss=None):
+        self.optim = optim if optim is not None else SGD()
+        self.loss = loss if loss is not None else Squared()
+        self.mloss = mloss if mloss is not None else loss
+        self.learn_rate = nf(learn_rate)
+        self.learn_anneal = nf(learn_anneal)
+        self.learn_advance = nf(learn_advance)
+
+    def measure(self, residual):
+        return self.mloss.mean(residual)
+
+    def derive(self, residual):
+        return self.loss.dmean(residual)
+
+    def update(self, dW, W):
+        self.optim.update(dW, W)
+
+    def prepare(self, epoch):
+        self.optim.alpha = self.learn_rate * self.learn_anneal**epoch
+
+    def restart(self, optim=False):
+        self.learn_rate *= self.learn_anneal**self.learn_advance
+        if optim:
+            self.optim.reset()
+
+    def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
+        cumsum_loss = 0
+        batch_count = inputs.shape[0] // batch_size
+        losses = []
+        for b in range(batch_count):
+            bi = b * batch_size
+            batch_inputs = inputs[ bi:bi+batch_size]
+            batch_outputs = outputs[bi:bi+batch_size]
+
+            predicted = model.forward(batch_inputs)
+            residual = predicted - batch_outputs
+            model.backward(self.derive(residual))
+            self.update(model.dW, model.W)
+
+            batch_loss = self.measure(residual)
+            cumsum_loss += batch_loss
+            if return_losses:
+                losses.append(batch_loss)
+
+        avg_loss = cumsum_loss / batch_count
+        if return_losses:
+            return avg_loss, losses
+        else:
+            return avg_loss
+
 def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
     y = x
     last_size = x.output_shape[0]
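
The new Ritual class above bundles the optimizer, the training and monitoring losses, and the learning-rate schedule behind one object; the hunks further down wire it into the training loop. A condensed sketch of the call pattern, where `model`, `inputs`, `outputs`, `epochs`, and `batch_size` are placeholders rather than definitions from this diff:

    ritual = Ritual(optim=SGD(), learn_rate=1e-3, learn_anneal=0.99,
                    learn_advance=16, loss=Squared(), mloss=Squared())

    for epoch in range(epochs):
        ritual.prepare(epoch)   # sets optim.alpha = learn_rate * learn_anneal**epoch
        avg_loss = ritual.train_batched(model, inputs, outputs, batch_size)

    # between restarts: advance the learning-rate schedule, optionally reset optimizer state
    ritual.restart(optim=True)
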
@@ -516,7 +570,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
     return y
 
 
-if __name__ == '__main__':
+def run(program, args=[]):
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
     def log(left, right):
@@ -527,6 +581,7 @@ if __name__ == '__main__':
     from dotmap import DotMap
     config = DotMap(
         fn = 'ml/cie_mlp_min.h5',
+        log_fn = 'losses.npz',
 
         # multi-residual network parameters
         res_width = 49,
@@ -553,7 +608,6 @@ if __name__ == '__main__':
         batch_size = 64,
         init = 'he_normal',
         loss = SomethingElse(4/3),
-        log_fn = 'losses.npz',
         mloss = 'mse',
         restart_optim = True, # restarts also reset internal state of optimizer
         unsafe = True, # aka gotta go fast mode
@@ -632,11 +686,14 @@ if __name__ == '__main__':
     loss = lookup_loss(config.loss)
     mloss = lookup_loss(config.mloss) if config.mloss else loss
 
-    LR = config.LR
-    LRprod = 0.5**(1/config.LR_halve_every)
+    anneal = 0.5**(1/config.LR_halve_every)
+    ritual = Ritual(optim=optim,
+                    learn_rate=config.LR, learn_anneal=anneal,
+                    learn_advance=config.LR_restart_advance,
+                    loss=loss, mloss=mloss)
 
-    LRE = LR * (LRprod**config.LR_restart_advance)**config.restarts * LRprod**(config.epochs - 1)
-    log("final learning rate", "{:10.8f}".format(LRE))
+    learn_end = config.LR * (anneal**config.LR_restart_advance)**config.restarts * anneal**(config.epochs - 1)
+    log("final learning rate", "{:10.8f}".format(learn_end))
 
     # Training
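
As a sanity check on the `learn_end` expression above, a worked example with hypothetical config values (these numbers are not taken from the diff):

    LR, LR_halve_every, LR_restart_advance, restarts, epochs = 1e-2, 16, 16, 2, 24
    anneal = 0.5**(1/LR_halve_every)
    learn_end = LR * (anneal**LR_restart_advance)**restarts * anneal**(epochs - 1)
    # two restarts each advance the schedule by one halving (16 epochs), and the final
    # epoch index is 23, so learn_end = 1e-2 * 0.25 * 0.5**(23/16) ≈ 0.00092
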
@@ -645,24 +702,21 @@ if __name__ == '__main__':
     valid_losses = []
 
     def measure_error():
-        # log("weight mean", "{:11.7f}".format(np.mean(model.W)))
-        # log("weight var", "{:11.7f}".format(np.var(model.W)))
-
         def print_error(name, inputs, outputs, comparison=None):
             predicted = model.forward(inputs)
             residual = predicted - outputs
-            err = mloss.mean(residual)
+            err = ritual.measure(residual)
             log(name + " loss", "{:11.7f}".format(err))
             if comparison:
                 log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
             return err
 
         train_err = print_error("train",
-            inputs / x_scale, outputs / y_scale,
-            config.train_compare)
+                                inputs / x_scale, outputs / y_scale,
+                                config.train_compare)
 
         valid_err = print_error("valid",
-            valid_inputs / x_scale, valid_outputs / y_scale,
-            config.valid_compare)
+                                valid_inputs / x_scale, valid_outputs / y_scale,
+                                config.valid_compare)
 
         train_losses.append(train_err)
         valid_losses.append(valid_err)
@@ -671,38 +725,25 @@ if __name__ == '__main__':
 
         if i > 0:
             log("restarting", i)
-            LR *= LRprod**config.LR_restart_advance
-            if config.restart_optim:
-                optim.reset()
+            ritual.restart(optim=config.restart_optim)
 
         assert inputs.shape[0] % config.batch_size == 0, \
             "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
-        batch_count = inputs.shape[0] // config.batch_size
 
         for e in range(config.epochs):
            indices = np.arange(inputs.shape[0])
             np.random.shuffle(indices)
             shuffled_inputs = inputs[indices] / x_scale
             shuffled_outputs = outputs[indices] / y_scale
 
-            optim.alpha = LR * LRprod**e
-            #log("learning rate", "{:10.8f}".format(optim.alpha))
+            ritual.prepare(e)
+            #log("learning rate", "{:10.8f}".format(ritual.optim.alpha))
 
-            cumsum_loss = 0
-            for b in range(batch_count):
-                bi = b * config.batch_size
-                batch_inputs = shuffled_inputs[ bi:bi+config.batch_size]
-                batch_outputs = shuffled_outputs[bi:bi+config.batch_size]
-
-                predicted = model.forward(batch_inputs)
-                residual = predicted - batch_outputs
-                dW = model.backward(loss.dmean(residual))
-                optim.update(dW, model.W)
-
-                # note: we don't actually need this for training, only monitoring.
-                batch_loss = mloss.mean(residual)
-                cumsum_loss += batch_loss
-                batch_losses.append(batch_loss)
-            #log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
+            avg_loss, losses = ritual.train_batched(model,
+                shuffled_inputs, shuffled_outputs,
+                config.batch_size,
+                return_losses=True)
+            log("average loss", "{:11.7f}".format(avg_loss))
+            batch_losses += losses
 
         measure_error()
@@ -723,3 +764,9 @@ if __name__ == '__main__':
             batch_losses=nfa(batch_losses),
             train_losses=nfa(train_losses),
             valid_losses=nfa(valid_losses))
+
+    return 0
+
+if __name__ == '__main__':
+    import sys
+    sys.exit(run(sys.argv[0], sys.argv[1:]))