This commit is contained in:
parent d299520fd9
commit e28eb0cf06
1 changed file with 43 additions and 17 deletions

optim_nn.py (60 changed lines)
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 
-# imports
-
 import numpy as np
 nf = np.float32
 nfa = lambda x: np.array(x, dtype=nf)
@@ -31,7 +29,7 @@ class SquaredHalved(Loss):
 
 class Optimizer:
     def __init__(self, alpha=0.1):
-        self.alpha = nfa(alpha)
+        self.alpha = nf(alpha)
         self.reset()
 
     def reset(self):
@@ -43,9 +41,33 @@ class Optimizer:
     def update(self, dW, W):
         W += self.compute(dW, W)
 
-# the following optimizers are blatantly ripped from tiny-dnn:
+# the following optimizers are blatantly lifted from tiny-dnn:
 # https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h
 
+class Momentum(Optimizer):
+    def __init__(self, alpha=0.01, lamb=0, mu=0.9, nesterov=False):
+        self.alpha = np.asfarray(alpha) # learning rate
+        self.lamb = np.asfarray(lamb) # weight decay
+        self.mu = np.asfarray(mu) # momentum
+        self.nesterov = bool(nesterov)
+
+        self.reset()
+
+    def reset(self):
+        self.dWprev = None
+
+    def compute(self, dW, W):
+        if self.dWprev is None:
+            #self.dWprev = np.zeros_like(dW)
+            self.dWprev = np.copy(dW)
+
+        V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
+        self.dWprev = V
+        if self.nesterov:
+            return self.mu * V - self.alpha * (dW + W * self.lamb)
+        else:
+            return V
+
 class Adam(Optimizer):
     def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
         self.alpha = nf(alpha) # learning rate
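A minimal standalone sketch of the update rule the new Momentum class implements, run on a 1-D quadratic loss (0.5 * w**2, so dW = w); the helper name momentum_step and the sample values are illustrative only, not part of the commit:

import numpy as np

def momentum_step(dW, W, state, alpha=0.01, lamb=0.0, mu=0.9, nesterov=False):
    # same rule as Momentum.compute: V = mu * V_prev - alpha * (dW + lamb * W)
    if state["dWprev"] is None:
        state["dWprev"] = np.copy(dW)
    V = mu * state["dWprev"] - alpha * (dW + W * lamb)
    state["dWprev"] = V
    if nesterov:
        # Nesterov variant takes the step along the already-updated velocity
        return mu * V - alpha * (dW + W * lamb)
    return V

w = np.float32(1.0)
state = {"dWprev": None}
for _ in range(300):
    dw = w  # gradient of 0.5 * w**2
    w += momentum_step(dw, w, state, nesterov=True)
print(w)  # ends near the minimum at 0; the run starts uphill because dWprev is seeded with the first gradient rather than zeros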
@@ -107,7 +129,7 @@ class Layer:
         raise NotImplementedError("unimplemented", self)
 
     def do_feed(self, child):
-        pass
+        self.children.append(child)
 
     def be_fed(self, parent):
         self.parents.append(parent)
@@ -154,7 +176,6 @@ class Layer:
         if not child.compatible(self):
             fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
             raise Exception(fmt.format(self, child, self.output_shape, child.input_shape))
-        self.children.append(child)
         self.do_feed(child)
         child.be_fed(self)
         return child
@@ -375,14 +396,14 @@ class Model:
         #return lut[input_node] # meaningless value
         return self.dW
 
-    def load_model(self, fn):
+    def load_weights(self, fn):
         # seemingly compatible with keras models at the moment
         import h5py
         f = h5py.File(fn)
-        loadweights = {}
+        weights = {}
         def visitor(name, obj):
             if isinstance(obj, h5py.Dataset):
-                loadweights[name.split('/')[-1]] = nfa(obj[:])
+                weights[name.split('/')[-1]] = nfa(obj[:])
         f.visititems(visitor)
         f.close()
 
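The renamed load_weights walks the HDF5 file with h5py's visititems and keeps every Dataset leaf, keyed by the last path component. A hedged sketch of the Keras-style layout it appears to expect; the group path model_weights/dense_1 and the toy shapes are assumptions for illustration, while the dense_N_W / dense_N_b names appear in the next hunk:

import h5py
import numpy as np

# write a toy file using keras-style dataset names (layout assumed for illustration)
with h5py.File("toy_weights.h5", "w") as f:
    g = f.create_group("model_weights/dense_1")
    g.create_dataset("dense_1_W", data=np.random.randn(4, 3).astype(np.float32))
    g.create_dataset("dense_1_b", data=np.zeros(3, dtype=np.float32))

# read it back the way load_weights does: collect every Dataset leaf by its last name component
weights = {}
def visitor(name, obj):
    if isinstance(obj, h5py.Dataset):
        weights[name.split('/')[-1]] = np.array(obj[:], dtype=np.float32)

with h5py.File("toy_weights.h5", "r") as f:
    f.visititems(visitor)

print(sorted(weights))  # ['dense_1_W', 'dense_1_b']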
@@ -390,10 +411,10 @@
         for i in range(len(denses)):
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
-            denses[a].coeffs = loadweights[b_name+'_W'].T
-            denses[a].biases = np.expand_dims(loadweights[b_name+'_b'], -1)
+            denses[a].coeffs = weights[b_name+'_W'].T
+            denses[a].biases = np.expand_dims(weights[b_name+'_b'], -1)
 
-    def save_model(self, fn, overwrite=False):
+    def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)
 
 if __name__ == '__main__':
@@ -488,13 +509,18 @@ if __name__ == '__main__':
     if not training:
         model.load_weights(config.fn)
 
-    assert config.optim == 'adam'
-    if config.nesterov:
-        assert False, "unimplemented"
-    else:
+    if config.optim == 'adam':
+        assert not config.nesterov, "unimplemented"
         optim = Adam()
+    elif config.optim == 'sgd':
+        if config.momentum != 0:
+            optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
+        else:
+            optim = Optimizer()
+    else:
+        raise Exception('unknown optimizer', config.optim)
 
-    assert config.loss == 'mse'
+    assert config.loss == 'mse', 'unknown loss function'
     loss = SquaredHalved()
 
     LR = config.LR
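For context, a hedged sketch of how a config object carrying the fields this block reads (optim, momentum, nesterov) drives the new dispatch; SimpleNamespace and the stub classes stand in for the real argparse config and the optimizers defined in optim_nn.py:

from types import SimpleNamespace

# stand-ins for the real classes in optim_nn.py (stubs for illustration only)
class Optimizer: pass
class Adam(Optimizer): pass
class Momentum(Optimizer):
    def __init__(self, mu=0.9, nesterov=False):
        self.mu, self.nesterov = mu, nesterov

config = SimpleNamespace(optim='sgd', momentum=0.9, nesterov=True)

# same dispatch as the rewritten block above
if config.optim == 'adam':
    assert not config.nesterov, "unimplemented"
    optim = Adam()
elif config.optim == 'sgd':
    if config.momentum != 0:
        optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
    else:
        optim = Optimizer()
else:
    raise Exception('unknown optimizer', config.optim)

print(type(optim).__name__)  # Momentum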