From f12e408c7e0eebe7d19f9b051c80e29b9f2cb1a5 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Thu, 12 Jan 2017 14:45:07 -0800
Subject: [PATCH] .

---
 optim_nn.py | 123 ++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 85 insertions(+), 38 deletions(-)

diff --git a/optim_nn.py b/optim_nn.py
index 357b0fc..87acc4f 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -98,7 +98,7 @@ class Momentum(Optimizer):
             self.dWprev = np.copy(dW)
 
         V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
-        self.dWprev = V
+        self.dWprev[:] = V
         if self.nesterov:
             return self.mu * V - self.alpha * (dW + W * self.lamb)
         else:
@@ -412,7 +412,7 @@ class Model:
             offset += node.size
 
     def traverse(self, nodes, node):
-        if node == x:
+        if node == self.x:
             return [node]
         for parent in node.parents:
             if parent not in nodes:
@@ -465,6 +465,60 @@ class Model:
     def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)
 
+class Ritual:
+    def __init__(self,
+                 optim=None,
+                 learn_rate=1e-3, learn_anneal=1, learn_advance=0,
+                 loss=None, mloss=None):
+        self.optim = optim if optim is not None else SGD()
+        self.loss = loss if loss is not None else Squared()
+        self.mloss = mloss if mloss is not None else self.loss
+        self.learn_rate = nf(learn_rate)
+        self.learn_anneal = nf(learn_anneal)
+        self.learn_advance = nf(learn_advance)
+
+    def measure(self, residual):
+        return self.mloss.mean(residual)
+
+    def derive(self, residual):
+        return self.loss.dmean(residual)
+
+    def update(self, dW, W):
+        self.optim.update(dW, W)
+
+    def prepare(self, epoch):
+        self.optim.alpha = self.learn_rate * self.learn_anneal**epoch
+
+    def restart(self, optim=False):
+        self.learn_rate *= self.learn_anneal**self.learn_advance
+        if optim:
+            self.optim.reset()
+
+    def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
+        cumsum_loss = 0
+        batch_count = inputs.shape[0] // batch_size
+        losses = []
+        for b in range(batch_count):
+            bi = b * batch_size
+            batch_inputs  = inputs[ bi:bi+batch_size]
+            batch_outputs = outputs[bi:bi+batch_size]
+
+            predicted = model.forward(batch_inputs)
+            residual = predicted - batch_outputs
+
+            model.backward(self.derive(residual))
+            self.update(model.dW, model.W)
+
+            batch_loss = self.measure(residual)
+            cumsum_loss += batch_loss
+            if return_losses:
+                losses.append(batch_loss)
+        avg_loss = cumsum_loss / batch_count
+        if return_losses:
+            return avg_loss, losses
+        else:
+            return avg_loss
+
 def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
     y = x
     last_size = x.output_shape[0]
@@ -516,7 +570,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
 
     return y
 
-if __name__ == '__main__':
+def run(program, args=[]):
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
     def log(left, right):
@@ -527,6 +581,7 @@
     from dotmap import DotMap
     config = DotMap(
         fn = 'ml/cie_mlp_min.h5',
+        log_fn = 'losses.npz',
 
         # multi-residual network parameters
         res_width = 49,
@@ -553,7 +608,6 @@
         batch_size = 64,
         init = 'he_normal',
         loss = SomethingElse(4/3),
-        log_fn = 'losses.npz',
         mloss = 'mse',
         restart_optim = True, # restarts also reset internal state of optimizer
         unsafe = True, # aka gotta go fast mode
@@ -632,11 +686,14 @@
     loss = lookup_loss(config.loss)
     mloss = lookup_loss(config.mloss) if config.mloss else loss
 
-    LR = config.LR
-    LRprod = 0.5**(1/config.LR_halve_every)
+    anneal = 0.5**(1/config.LR_halve_every)
+    ritual = Ritual(optim=optim,
+                    learn_rate=config.LR, learn_anneal=anneal,
+                    learn_advance=config.LR_restart_advance,
+                    loss=loss, mloss=mloss)
 
-    LRE = LR * (LRprod**config.LR_restart_advance)**config.restarts * LRprod**(config.epochs - 1)
-    log("final learning rate", "{:10.8f}".format(LRE))
+    learn_end = config.LR * (anneal**config.LR_restart_advance)**config.restarts * anneal**(config.epochs - 1)
+    log("final learning rate", "{:10.8f}".format(learn_end))
 
     # Training
 
@@ -645,24 +702,21 @@
     valid_losses = []
 
     def measure_error():
-        # log("weight mean", "{:11.7f}".format(np.mean(model.W)))
-        # log("weight var", "{:11.7f}".format(np.var(model.W)))
-
         def print_error(name, inputs, outputs, comparison=None):
             predicted = model.forward(inputs)
             residual = predicted - outputs
-            err = mloss.mean(residual)
+            err = ritual.measure(residual)
             log(name + " loss", "{:11.7f}".format(err))
             if comparison:
                 log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
             return err
 
         train_err = print_error("train",
-            inputs / x_scale, outputs / y_scale,
-            config.train_compare)
+                                inputs / x_scale, outputs / y_scale,
+                                config.train_compare)
         valid_err = print_error("valid",
-            valid_inputs / x_scale, valid_outputs / y_scale,
-            config.valid_compare)
+                                valid_inputs / x_scale, valid_outputs / y_scale,
+                                config.valid_compare)
         train_losses.append(train_err)
         valid_losses.append(valid_err)
 
@@ -671,38 +725,25 @@
         if i > 0:
             log("restarting", i)
-            LR *= LRprod**config.LR_restart_advance
-            if config.restart_optim:
-                optim.reset()
+            ritual.restart(optim=config.restart_optim)
 
         assert inputs.shape[0] % config.batch_size == 0, \
                "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
-        batch_count = inputs.shape[0] // config.batch_size
 
         for e in range(config.epochs):
            indices = np.arange(inputs.shape[0])
            np.random.shuffle(indices)
            shuffled_inputs = inputs[indices] / x_scale
            shuffled_outputs = outputs[indices] / y_scale
 
-            optim.alpha = LR * LRprod**e
-            #log("learning rate", "{:10.8f}".format(optim.alpha))
+            ritual.prepare(e)
+            #log("learning rate", "{:10.8f}".format(ritual.optim.alpha))
 
-            cumsum_loss = 0
-            for b in range(batch_count):
-                bi = b * config.batch_size
-                batch_inputs  = shuffled_inputs[ bi:bi+config.batch_size]
-                batch_outputs = shuffled_outputs[bi:bi+config.batch_size]
-
-                predicted = model.forward(batch_inputs)
-                residual = predicted - batch_outputs
-                dW = model.backward(loss.dmean(residual))
-                optim.update(dW, model.W)
-
-                # note: we don't actually need this for training, only monitoring.
-                batch_loss = mloss.mean(residual)
-                cumsum_loss += batch_loss
-                batch_losses.append(batch_loss)
-            #log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
+            avg_loss, losses = ritual.train_batched(model,
+                shuffled_inputs, shuffled_outputs,
+                config.batch_size,
+                return_losses=True)
+            log("average loss", "{:11.7f}".format(avg_loss))
+            batch_losses += losses
 
         measure_error()
 
@@ -723,3 +764,9 @@
             batch_losses=nfa(batch_losses),
             train_losses=nfa(train_losses),
             valid_losses=nfa(valid_losses))
+
+    return 0
+
+if __name__ == '__main__':
+    import sys
+    sys.exit(run(sys.argv[0], sys.argv[1:]))
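
Usage sketch (editor's note, not part of the patch): a minimal driver for the
new Ritual class, assuming Ritual, SGD, and Squared are importable from
optim_nn as defined in this file. "model", "inputs", "outputs", and the
hyperparameter values below are hypothetical placeholders.

    from optim_nn import Ritual, SGD, Squared

    restarts = 2   # placeholder
    epochs = 24    # placeholder
    # model, inputs, outputs: an optim_nn Model and scaled data arrays
    # (placeholders; see run() above for how they are actually built)

    ritual = Ritual(optim=SGD(),
                    learn_rate=1e-2,           # starting optim.alpha
                    learn_anneal=0.5**(1/16),  # halves alpha every 16 epochs
                    learn_advance=16,          # restart() skips 16 epochs ahead
                    loss=Squared())            # mloss defaults to loss

    for i in range(restarts + 1):
        if i > 0:
            # advance learn_rate and reset the optimizer's internal state
            ritual.restart(optim=True)
        for e in range(epochs):
            ritual.prepare(e)  # sets optim.alpha for this epoch
            avg_loss, losses = ritual.train_batched(model, inputs, outputs,
                                                    batch_size=64,
                                                    return_losses=True)
            print("epoch", e, "average loss", avg_loss)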
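
Schedule check (editor's note, not part of the patch): the "final learning
rate" logged by run() composes the per-epoch anneal with the per-restart
advance. A small self-contained sketch of that arithmetic, with hypothetical
values for LR, LR_halve_every, LR_restart_advance, restarts, and epochs:

    LR = 1e-2
    LR_halve_every = 16
    LR_restart_advance = 16
    restarts = 2
    epochs = 24

    anneal = 0.5**(1/LR_halve_every)

    # replay the schedule: restart() multiplies the base rate by
    # anneal**LR_restart_advance; prepare(e) scales it by anneal**e.
    rate = LR
    for i in range(restarts + 1):
        if i > 0:
            rate *= anneal**LR_restart_advance
        final_alpha = rate * anneal**(epochs - 1)  # alpha on the last epoch

    # closed form, as computed for the "final learning rate" log line
    learn_end = LR * (anneal**LR_restart_advance)**restarts * anneal**(epochs - 1)
    assert abs(final_alpha - learn_end) < 1e-12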