From e28eb0cf06f3cae6c07ae561913cf9dfa73d9765 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Tue, 10 Jan 2017 04:27:49 -0800
Subject: [PATCH] .

---
 optim_nn.py | 60 ++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 43 insertions(+), 17 deletions(-)

diff --git a/optim_nn.py b/optim_nn.py
index 7d50c05..86c50df 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 
-# imports
-
 import numpy as np
 nf = np.float32
 nfa = lambda x: np.array(x, dtype=nf)
@@ -31,7 +29,7 @@ class SquaredHalved(Loss):
 
 class Optimizer:
     def __init__(self, alpha=0.1):
-        self.alpha = nfa(alpha)
+        self.alpha = nf(alpha)
         self.reset()
 
     def reset(self):
@@ -43,9 +41,33 @@ class Optimizer:
     def update(self, dW, W):
         W += self.compute(dW, W)
 
-# the following optimizers are blatantly ripped from tiny-dnn:
+# the following optimizers are blatantly lifted from tiny-dnn:
 # https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h
 
+class Momentum(Optimizer):
+    def __init__(self, alpha=0.01, lamb=0, mu=0.9, nesterov=False):
+        self.alpha = np.asfarray(alpha) # learning rate
+        self.lamb = np.asfarray(lamb) # weight decay
+        self.mu = np.asfarray(mu) # momentum
+        self.nesterov = bool(nesterov)
+
+        self.reset()
+
+    def reset(self):
+        self.dWprev = None
+
+    def compute(self, dW, W):
+        if self.dWprev is None:
+            #self.dWprev = np.zeros_like(dW)
+            self.dWprev = np.copy(dW)
+
+        V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
+        self.dWprev = V
+        if self.nesterov:
+            return self.mu * V - self.alpha * (dW + W * self.lamb)
+        else:
+            return V
+
 class Adam(Optimizer):
     def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
         self.alpha = nf(alpha) # learning rate
@@ -107,7 +129,7 @@ class Layer:
         raise NotImplementedError("unimplemented", self)
 
     def do_feed(self, child):
-        pass
+        self.children.append(child)
 
     def be_fed(self, parent):
         self.parents.append(parent)
@@ -154,7 +176,6 @@ class Layer:
         if not child.compatible(self):
             fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
             raise Exception(fmt.format(self, child, self.output_shape, child.input_shape))
-        self.children.append(child)
         self.do_feed(child)
         child.be_fed(self)
         return child
@@ -375,14 +396,14 @@ class Model:
         #return lut[input_node] # meaningless value
         return self.dW
 
-    def load_model(self, fn):
+    def load_weights(self, fn):
         # seemingly compatible with keras models at the moment
         import h5py
         f = h5py.File(fn)
-        loadweights = {}
+        weights = {}
         def visitor(name, obj):
             if isinstance(obj, h5py.Dataset):
-                loadweights[name.split('/')[-1]] = nfa(obj[:])
+                weights[name.split('/')[-1]] = nfa(obj[:])
         f.visititems(visitor)
         f.close()
 
@@ -390,10 +411,10 @@ class Model:
         for i in range(len(denses)):
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
-            denses[a].coeffs = loadweights[b_name+'_W'].T
-            denses[a].biases = np.expand_dims(loadweights[b_name+'_b'], -1)
+            denses[a].coeffs = weights[b_name+'_W'].T
+            denses[a].biases = np.expand_dims(weights[b_name+'_b'], -1)
 
-    def save_model(self, fn, overwrite=False):
+    def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)
 
 if __name__ == '__main__':
@@ -488,13 +509,18 @@ if __name__ == '__main__':
     if not training:
         model.load_weights(config.fn)
 
-    assert config.optim == 'adam'
-    if config.nesterov:
-        assert False, "unimplemented"
-    else:
+    if config.optim == 'adam':
+        assert not config.nesterov, "unimplemented"
         optim = Adam()
+    elif config.optim == 'sgd':
+        if config.momentum != 0:
+            optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
+        else:
+            optim = Optimizer()
+    else:
+        raise Exception('unknown optimizer', config.optim)
 
-    assert config.loss == 'mse'
+    assert config.loss == 'mse', 'unknown loss function'
     loss = SquaredHalved()
 
     LR = config.LR