From 8fc1c198b4d37e80f9b19e6818f4ca5ddd8ad8bd Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Wed, 8 Feb 2017 18:10:25 -0800 Subject: [PATCH] . --- optim_nn.py | 115 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 90 insertions(+), 25 deletions(-) diff --git a/optim_nn.py b/optim_nn.py index 2a0ba34..35e0e87 100644 --- a/optim_nn.py +++ b/optim_nn.py @@ -12,6 +12,7 @@ from scipy.special import expit as sigmoid # used for numbering layers like Keras: from collections import defaultdict +_layer_counters = defaultdict(lambda: 0) # Initializations @@ -142,8 +143,6 @@ class Adam(Optimizer): # Abstract Layers -_layer_counters = defaultdict(lambda: 0) - class Layer: def __init__(self): self.parents = [] @@ -297,7 +296,6 @@ class Affine(Layer): class Sigmoid(Layer): # aka Logistic def F(self, X): - from scipy.special import expit as sigmoid self.sig = sigmoid(X) return X * self.sig @@ -322,6 +320,7 @@ class Relu(Layer): class Elu(Layer): # paper: https://arxiv.org/abs/1511.07289 + def __init__(self, alpha=1): super().__init__() self.alpha = nf(alpha) @@ -337,6 +336,7 @@ class Elu(Layer): class GeluApprox(Layer): # paper: https://arxiv.org/abs/1606.08415 # plot: https://www.desmos.com/calculator/ydzgtccsld + def F(self, X): self.a = 1.704 * X self.sig = sigmoid(self.a) @@ -407,7 +407,10 @@ class DenseOneLess(Dense): np.fill_diagonal(self.dcoeffs, 0) return dX -class LayerNorm(Layer): # TODO: inherit Affine instead? +class LayerNorm(Layer): + # paper: https://arxiv.org/abs/1607.06450 + # my implementation may be incorrect. + def __init__(self, eps=1e-3, axis=-1): super().__init__() self.eps = nf(eps) @@ -556,13 +559,17 @@ class Ritual: # i'm just making up names at this point self.learner.optim.update(self.model.dW, self.model.W) def prepare(self, model): + self.en = 0 + self.bn = 0 self.model = model def train_batched(self, inputs, outputs, batch_size, return_losses=False): + self.en += 1 cumsum_loss = 0 batch_count = inputs.shape[0] // batch_size losses = [] for b in range(batch_count): + self.bn += 1 bi = b * batch_size batch_inputs = inputs[ bi:bi+batch_size] batch_outputs = outputs[bi:bi+batch_size] @@ -587,6 +594,7 @@ class Ritual: # i'm just making up names at this point def stochastic_multiply(W, gamma=0.5, allow_negation=True): # paper: https://arxiv.org/abs/1606.01981 + assert W.ndim == 1, W.ndim assert 0 < gamma < 1, gamma size = len(W) @@ -632,7 +640,28 @@ class StochMRitual(Ritual): for layer in self.model.ordered_nodes: if isinstance(layer, Dense): np.clip(layer.W, -layer.std * f, layer.std * f, out=layer.W) - # np.clip(layer.W, -1, 1, out=layer.W) + # np.clip(layer.W, -1, 1, out=layer.W) + +class NoisyRitual(Ritual): + def __init__(self, learner=None, loss=None, mloss=None, + input_noise=0, output_noise=0, gradient_noise=0): + self.input_noise = nf(input_noise) # TODO: implement + self.output_noise = nf(output_noise) # TODO: implement + self.gradient_noise = nf(gradient_noise) + super().__init__(learner, loss, mloss) + + def update(self): + # gradient noise paper: https://arxiv.org/abs/1511.06807 + if self.gradient_noise > 0: + size = len(self.model.dW) + gamma = 0.55 + s = self.gradient_noise / (1 + self.bn) ** gamma + # experiments: + #s = np.sqrt(self.learner.rate) + #s = np.square(self.learner.rate) + #s = self.learner.rate / self.en + self.model.dW += np.random.normal(0, s, size=size) + super().update() class Learner: per_batch = False @@ -733,28 +762,46 @@ def cosmod(x): class SGDR(Learner): # Stochastic Gradient Descent with Restarts # paper: https://arxiv.org/abs/1608.03983 - # NOTE: this is not a complete implementation. + # NOTE: this is missing a couple features. + per_batch = True - def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None): + def __init__(self, optim, epochs=100, rate=None, + restarts=0, restart_decay=0.5, callback=None, + expando=None): self.restart_epochs = int(epochs) self.decay = float(restart_decay) self.restarts = int(restarts) self.restart_callback = callback - epochs = self.restart_epochs * (self.restarts + 1) + # TODO: rename expando to something not insane + self.expando = expando if expando is not None else lambda i: 1 + + self.splits = [] + epochs = 0 + for i in range(0, self.restarts + 1): + split = epochs + int(self.restart_epochs * self.expando(i)) + self.splits.append(split) + epochs = split super().__init__(optim, epochs, rate) + def split_num(self, epoch): + shit = [0] + self.splits # hack + for i in range(0, len(self.splits)): + if epoch < self.splits[i]: + sub_epoch = epoch - shit[i] + next_restart = self.splits[i] - shit[i] + return i, sub_epoch, next_restart + raise Exception('this should never happen.') + def rate_at(self, epoch): - sub_epoch = epoch % self.restart_epochs - x = sub_epoch / self.restart_epochs - restart = epoch // self.restart_epochs + restart, sub_epoch, next_restart = self.split_num(epoch) + x = sub_epoch / next_restart return self.start_rate * self.decay**restart * cosmod(x) def next(self): if not super().next(): return False - sub_epoch = self.epoch % self.restart_epochs - restart = self.epoch // self.restart_epochs + restart, sub_epoch, next_restart = self.split_num(self.epoch) if restart > 0 and sub_epoch == 0: if self.restart_callback is not None: self.restart_callback(restart) @@ -789,6 +836,7 @@ def multiresnet(x, width, depth, block=2, multi=1, z.feed(merger) y = merger elif style == 'onelesssum': + # this is my own awful contraption. is_last = d + 1 == depth needs_sum = not is_last or multi > 1 skip = y @@ -845,16 +893,17 @@ def run(program, args=[]): optim = 'adam', nesterov = False, # only used with SGD or Adam - momentum = 0.33, # only used with SGD + momentum = 0.50, # only used with SGD # learning parameters - learner = 'SGDR', + learner = 'sgdr', learn = 1e-2, - epochs = 24, - restarts = 2, - learn_decay = 0.25, # only used with SGDR learn_halve_every = 16, # unused with SGDR learn_restart_advance = 16, # unused with SGDR + epochs = 12, + restarts = 2, + restart_decay = 1, # only used with SGDR + expando = lambda i: i + 1, # misc batch_size = 64, @@ -866,10 +915,13 @@ def run(program, args=[]): train_compare = 0.0000508, valid_compare = 0.0000678, - ritual = None, + ritual = 'default', ) - config.pprint() + for k in ['parallel_style', 'optim', 'learner', 'ritual']: + config[k] = config[k].lower() + + #config.pprint() # toy CIE-2000 data from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, \ @@ -933,17 +985,27 @@ def run(program, args=[]): # - if config.learner == 'SGDR': + if config.learner == 'sgdr': + expando = config.expando if 'expando' in config else None learner = SGDR(optim, epochs=config.epochs, rate=config.learn, - restart_decay=config.learn_decay, restarts=config.restarts, - callback=rscb) + restart_decay=config.restart_decay, restarts=config.restarts, + callback=rscb, expando=expando) # final learning rate isn't of interest here; it's gonna be close to 0. - else: + log('total epochs:', learner.epochs) + elif config.learner == 'anneal': + learner = AnnealingLearner(optim, epochs=config.epochs, rate=config.learn, + halve_every=config.learn_halve_every) + elif config.learner == 'dumb': learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn, halve_every=config.learn_halve_every, restarts=config.restarts, restart_advance=config.learn_restart_advance, callback=rscb) log("final learning rate", "{:10.8f}".format(learner.final_rate)) + elif config.learner == 'sgd': + learner = Learner(optim, epochs=config.epochs, rate=config.learn) + log("final learning rate", "{:10.8f}".format(learner.final_rate)) + else: + raise Exception('unknown learner', config.learner) # @@ -961,10 +1023,13 @@ def run(program, args=[]): loss = lookup_loss(config.loss) mloss = lookup_loss(config.mloss) if config.mloss else loss - if config.ritual == None: + if config.ritual == 'default': ritual = Ritual(learner=learner, loss=loss, mloss=mloss) elif config.ritual == 'stochm': ritual = StochMRitual(learner=learner, loss=loss, mloss=mloss) + elif config.ritual == 'noisy': + ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss, + gradient_noise=0.01) else: raise Exception('unknown ritual', config.ritual)