commit e28eb0cf06 (parent d299520fd9)

1 changed file with 43 additions and 17 deletions

optim_nn.py (60 changed lines)
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3

-# imports
-
 import numpy as np
 nf = np.float32
 nfa = lambda x: np.array(x, dtype=nf)

@@ -31,7 +29,7 @@ class SquaredHalved(Loss):

 class Optimizer:
     def __init__(self, alpha=0.1):
-        self.alpha = nfa(alpha)
+        self.alpha = nf(alpha)
         self.reset()

     def reset(self):

@@ -43,9 +41,33 @@ class Optimizer:
     def update(self, dW, W):
         W += self.compute(dW, W)

-# the following optimizers are blatantly ripped from tiny-dnn:
+# the following optimizers are blatantly lifted from tiny-dnn:
 # https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h

+class Momentum(Optimizer):
+    def __init__(self, alpha=0.01, lamb=0, mu=0.9, nesterov=False):
+        self.alpha = np.asfarray(alpha) # learning rate
+        self.lamb = np.asfarray(lamb) # weight decay
+        self.mu = np.asfarray(mu) # momentum
+        self.nesterov = bool(nesterov)
+
+        self.reset()
+
+    def reset(self):
+        self.dWprev = None
+
+    def compute(self, dW, W):
+        if self.dWprev is None:
+            #self.dWprev = np.zeros_like(dW)
+            self.dWprev = np.copy(dW)
+
+        V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
+        self.dWprev = V
+        if self.nesterov:
+            return self.mu * V - self.alpha * (dW + W * self.lamb)
+        else:
+            return V
+
 class Adam(Optimizer):
     def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
         self.alpha = nf(alpha) # learning rate

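For reference, a minimal standalone sketch (not part of the diff) of the classical momentum step that the new Momentum.compute applies, using a toy quadratic objective f(W) = 0.5 * ||W||^2 so that the gradient is simply W; the hyperparameter values are only illustrative:

    import numpy as np

    alpha, mu, lamb = 0.01, 0.9, 0.0      # illustrative learning rate, momentum, weight decay
    W = np.ones(3, dtype=np.float32)      # toy parameters
    V = W.copy()                          # dWprev starts as a copy of the first gradient
    for _ in range(200):
        dW = W                            # gradient of 0.5*||W||^2 is W itself
        V = mu * V - alpha * (dW + W * lamb)  # same expression as Momentum.compute
        W += V                            # Optimizer.update applies the step in place
    print(W)                              # converges toward zero

With nesterov=True, the step returned is instead mu * V - alpha * (dW + W * lamb), evaluated after dWprev has been updated, as in the added compute method.
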
@@ -107,7 +129,7 @@ class Layer:
         raise NotImplementedError("unimplemented", self)

     def do_feed(self, child):
-        pass
+        self.children.append(child)

     def be_fed(self, parent):
         self.parents.append(parent)

@@ -154,7 +176,6 @@ class Layer:
         if not child.compatible(self):
             fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
             raise Exception(fmt.format(self, child, self.output_shape, child.input_shape))
-        self.children.append(child)
         self.do_feed(child)
         child.be_fed(self)
         return child

@@ -375,14 +396,14 @@ class Model:
         #return lut[input_node] # meaningless value
         return self.dW

-    def load_model(self, fn):
+    def load_weights(self, fn):
         # seemingly compatible with keras models at the moment
         import h5py
         f = h5py.File(fn)
-        loadweights = {}
+        weights = {}
         def visitor(name, obj):
             if isinstance(obj, h5py.Dataset):
-                loadweights[name.split('/')[-1]] = nfa(obj[:])
+                weights[name.split('/')[-1]] = nfa(obj[:])
         f.visititems(visitor)
         f.close()

@@ -390,10 +411,10 @@ class Model:
         for i in range(len(denses)):
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
-            denses[a].coeffs = loadweights[b_name+'_W'].T
-            denses[a].biases = np.expand_dims(loadweights[b_name+'_b'], -1)
+            denses[a].coeffs = weights[b_name+'_W'].T
+            denses[a].biases = np.expand_dims(weights[b_name+'_b'], -1)

-    def save_model(self, fn, overwrite=False):
+    def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)

 if __name__ == '__main__':

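For reference, a hedged sketch (not from the repository) of the HDF5 layout that load_weights expects: the visititems visitor keys every dataset by the last component of its path, and Dense layer i is then looked up as dense_{i}_W and dense_{i}_b. The file name, group name, and shapes below are hypothetical:

    import h5py
    import numpy as np

    with h5py.File("weights.h5", "w") as f:   # hypothetical file name
        g = f.create_group("dense_1")         # group names are ignored by the visitor
        # stored as (inputs, outputs); the loader transposes with .T
        g.create_dataset("dense_1_W", data=np.zeros((784, 100), dtype=np.float32))
        # stored as a flat vector; the loader expands it to a column with np.expand_dims
        g.create_dataset("dense_1_b", data=np.zeros(100, dtype=np.float32))
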
@@ -488,13 +509,18 @@ if __name__ == '__main__':
     if not training:
         model.load_weights(config.fn)

-    assert config.optim == 'adam'
-    if config.nesterov:
-        assert False, "unimplemented"
-    else:
+    if config.optim == 'adam':
+        assert not config.nesterov, "unimplemented"
         optim = Adam()
+    elif config.optim == 'sgd':
+        if config.momentum != 0:
+            optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
+        else:
+            optim = Optimizer()
+    else:
+        raise Exception('unknown optimizer', config.optim)

-    assert config.loss == 'mse'
+    assert config.loss == 'mse', 'unknown loss function'
     loss = SquaredHalved()

     LR = config.LR