This commit is contained in:
Connor Olding 2017-01-10 04:27:49 -08:00
parent d299520fd9
commit e28eb0cf06

View file

@ -1,7 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# imports
import numpy as np import numpy as np
nf = np.float32 nf = np.float32
nfa = lambda x: np.array(x, dtype=nf) nfa = lambda x: np.array(x, dtype=nf)
@ -31,7 +29,7 @@ class SquaredHalved(Loss):
class Optimizer: class Optimizer:
def __init__(self, alpha=0.1): def __init__(self, alpha=0.1):
self.alpha = nfa(alpha) self.alpha = nf(alpha)
self.reset() self.reset()
def reset(self): def reset(self):
@ -43,9 +41,33 @@ class Optimizer:
def update(self, dW, W): def update(self, dW, W):
W += self.compute(dW, W) W += self.compute(dW, W)
# the following optimizers are blatantly ripped from tiny-dnn: # the following optimizers are blatantly lifted from tiny-dnn:
# https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h # https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h
class Momentum(Optimizer):
def __init__(self, alpha=0.01, lamb=0, mu=0.9, nesterov=False):
self.alpha = np.asfarray(alpha) # learning rate
self.lamb = np.asfarray(lamb) # weight decay
self.mu = np.asfarray(mu) # momentum
self.nesterov = bool(nesterov)
self.reset()
def reset(self):
self.dWprev = None
def compute(self, dW, W):
if self.dWprev is None:
#self.dWprev = np.zeros_like(dW)
self.dWprev = np.copy(dW)
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
self.dWprev = V
if self.nesterov:
return self.mu * V - self.alpha * (dW + W * self.lamb)
else:
return V
class Adam(Optimizer): class Adam(Optimizer):
def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8): def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
self.alpha = nf(alpha) # learning rate self.alpha = nf(alpha) # learning rate
@ -107,7 +129,7 @@ class Layer:
raise NotImplementedError("unimplemented", self) raise NotImplementedError("unimplemented", self)
def do_feed(self, child): def do_feed(self, child):
pass self.children.append(child)
def be_fed(self, parent): def be_fed(self, parent):
self.parents.append(parent) self.parents.append(parent)
@ -154,7 +176,6 @@ class Layer:
if not child.compatible(self): if not child.compatible(self):
fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}" fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
raise Exception(fmt.format(self, child, self.output_shape, child.input_shape)) raise Exception(fmt.format(self, child, self.output_shape, child.input_shape))
self.children.append(child)
self.do_feed(child) self.do_feed(child)
child.be_fed(self) child.be_fed(self)
return child return child
@ -375,14 +396,14 @@ class Model:
#return lut[input_node] # meaningless value #return lut[input_node] # meaningless value
return self.dW return self.dW
def load_model(self, fn): def load_weights(self, fn):
# seemingly compatible with keras models at the moment # seemingly compatible with keras models at the moment
import h5py import h5py
f = h5py.File(fn) f = h5py.File(fn)
loadweights = {} weights = {}
def visitor(name, obj): def visitor(name, obj):
if isinstance(obj, h5py.Dataset): if isinstance(obj, h5py.Dataset):
loadweights[name.split('/')[-1]] = nfa(obj[:]) weights[name.split('/')[-1]] = nfa(obj[:])
f.visititems(visitor) f.visititems(visitor)
f.close() f.close()
@ -390,10 +411,10 @@ class Model:
for i in range(len(denses)): for i in range(len(denses)):
a, b = i, i + 1 a, b = i, i + 1
b_name = "dense_{}".format(b) b_name = "dense_{}".format(b)
denses[a].coeffs = loadweights[b_name+'_W'].T denses[a].coeffs = weights[b_name+'_W'].T
denses[a].biases = np.expand_dims(loadweights[b_name+'_b'], -1) denses[a].biases = np.expand_dims(weights[b_name+'_b'], -1)
def save_model(self, fn, overwrite=False): def save_weights(self, fn, overwrite=False):
raise NotImplementedError("unimplemented", self) raise NotImplementedError("unimplemented", self)
if __name__ == '__main__': if __name__ == '__main__':
@ -488,13 +509,18 @@ if __name__ == '__main__':
if not training: if not training:
model.load_weights(config.fn) model.load_weights(config.fn)
assert config.optim == 'adam' if config.optim == 'adam':
if config.nesterov: assert not config.nesterov, "unimplemented"
assert False, "unimplemented"
else:
optim = Adam() optim = Adam()
elif config.optim == 'sgd':
if config.momentum != 0:
optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
else:
optim = Optimizer()
else:
raise Exception('unknown optimizer', config.optim)
assert config.loss == 'mse' assert config.loss == 'mse', 'unknown loss function'
loss = SquaredHalved() loss = SquaredHalved()
LR = config.LR LR = config.LR