#!/usr/bin/env python3
import numpy as np
# ugly shorthand:
nf = np.float32
nfa = lambda x: np.array(x, dtype=nf)
ni = int  # np.int was only an alias for the builtin int (and is gone from newer numpy)
nia = lambda x: np.array(x, dtype=ni)
# just for speed, not strictly essential:
from scipy.special import expit as sigmoid
# used for numbering layers like Keras:
from collections import defaultdict
# Initializations
# note: these are currently only implemented for 2D shapes.
def init_he_normal(size, ins, outs):
s = np.sqrt(2 / ins)
return np.random.normal(0, s, size=size)
def init_he_uniform(size, ins, outs):
s = np.sqrt(6 / ins)
return np.random.uniform(-s, s, size=size)
# Loss functions
class Loss:
def mean(self, r):
return np.average(self.f(r))
def dmean(self, r):
d = self.df(r)
return d / len(d)
class Squared(Loss):
def f(self, r):
return np.square(r)
def df(self, r):
return 2 * r
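# usage sketch (illustrative, not part of the original flow): a Loss here operates
# on the residual r = predicted - target rather than on (predicted, target) pairs:
#   loss = Squared()
#   r = np.array([0.5, -1.0, 2.0])
#   loss.mean(r)   # == np.average(r**2) == 1.75
#   loss.dmean(r)  # == 2*r / len(r), i.e. d(mean(f(r)))/dr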
class SquaredHalved(Loss):
def f(self, r):
return np.square(r) / 2
def df(self, r):
return r
class SomethingElse(Loss):
# generalizes Absolute and SquaredHalved (|dx| = 1)
# plot: https://www.desmos.com/calculator/fagjg9vuz7
def __init__(self, a=4/3):
assert 1 <= a <= 2, "parameter out of range"
self.a = nf(a / 2)
self.b = nf(2 / a)
self.c = nf(2 / a - 1)
def f(self, r):
return self.a * np.abs(r)**self.b
def df(self, r):
return np.sign(r) * np.abs(r)**self.c
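# worked special cases (for reference): a=1 gives f(r) = abs(r)**2 / 2 and df(r) = r,
# matching SquaredHalved; a=2 gives f(r) = abs(r) and df(r) = sign(r), i.e. absolute
# error. values of a in between interpolate between the two.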
# Optimizers
class Optimizer:
def __init__(self, alpha=0.1):
self.alpha = nf(alpha)
self.reset()
def reset(self):
pass
def compute(self, dW, W):
return -self.alpha * dW
def update(self, dW, W):
W += self.compute(dW, W)
# the following optimizers are blatantly lifted from tiny-dnn:
# https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/optimizers/optimizer.h
class Momentum(Optimizer):
def __init__(self, alpha=0.01, lamb=0, mu=0.9, nesterov=False):
self.alpha = np.asfarray(alpha) # learning rate
self.lamb = np.asfarray(lamb) # weight decay
self.mu = np.asfarray(mu) # momentum
self.nesterov = bool(nesterov)
self.reset()
def reset(self):
self.dWprev = None
def compute(self, dW, W):
if self.dWprev is None:
#self.dWprev = np.zeros_like(dW)
self.dWprev = np.copy(dW)
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
self.dWprev[:] = V
if self.nesterov: # TODO: is this correct? looks weird
return self.mu * V - self.alpha * (dW + W * self.lamb)
else:
return V
class Adam(Optimizer):
def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
self.alpha = nf(alpha) # learning rate
self.b1 = nf(b1) # decay term
self.b2 = nf(b2) # decay term
self.b1_t_default = nf(b1_t) # decay term power t
self.b2_t_default = nf(b2_t) # decay term power t
self.eps = nf(eps)
self.reset()
def reset(self):
self.mt = None
self.vt = None
self.b1_t = self.b1_t_default
self.b2_t = self.b2_t_default
def compute(self, dW, W):
if self.mt is None:
self.mt = np.zeros_like(W)
if self.vt is None:
self.vt = np.zeros_like(W)
# decay
self.b1_t *= self.b1
self.b2_t *= self.b2
self.mt[:] = self.b1 * self.mt + (1 - self.b1) * dW
self.vt[:] = self.b2 * self.vt + (1 - self.b2) * dW * dW
return -self.alpha * (self.mt / (1 - self.b1_t)) \
/ np.sqrt((self.vt / (1 - self.b2_t)) + self.eps)
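# usage sketch (illustrative): an Optimizer mutates a flat parameter vector in place
# given its gradient, e.g.
#   W  = np.zeros(10, dtype=nf)
#   dW = np.ones(10, dtype=nf)
#   optim = Adam(alpha=1e-3)
#   optim.update(dW, W)  # W += optim.compute(dW, W)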
# Abstract Layers
_layer_counters = defaultdict(lambda: 0)
class Layer:
def __init__(self):
self.parents = []
self.children = []
self.input_shape = None
self.output_shape = None
kind = self.__class__.__name__
global _layer_counters
_layer_counters[kind] += 1
self.name = "{}_{}".format(kind, _layer_counters[kind])
self.size = None # total weight count (if any)
self.unsafe = False # disables assertions for better performance
def __str__(self):
return self.name
# methods we might want to override:
def F(self, X):
raise NotImplementedError("unimplemented", self)
def dF(self, dY):
raise NotImplementedError("unimplemented", self)
def do_feed(self, child):
self.children.append(child)
def be_fed(self, parent):
self.parents.append(parent)
def make_shape(self, shape):
if not self.unsafe:
assert shape is not None
if self.output_shape is None:
self.output_shape = shape
return shape
# TODO: rename this multi and B crap to something actually relevant.
def multi(self, B):
if not self.unsafe:
assert len(B) == 1, self
return self.F(B[0])
def dmulti(self, dB):
if len(dB) == 1:
return self.dF(dB[0])
else:
dX = None
for dY in dB:
if dX is None:
dX = self.dF(dY)
else:
dX += self.dF(dY)
return dX
# general utility methods:
def compatible(self, parent):
if self.input_shape is None:
# inherit shape from output
shape = self.make_shape(parent.output_shape)
if shape is None:
return False
self.input_shape = shape
        return bool(np.all(self.input_shape == parent.output_shape))
def feed(self, child):
if not child.compatible(self):
fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
raise Exception(fmt.format(self, child, self.output_shape, child.input_shape))
self.do_feed(child)
child.be_fed(self)
return child
def validate_input(self, X):
assert X.shape[1:] == self.input_shape, (str(self), X.shape[1:], self.input_shape)
def validate_output(self, Y):
assert Y.shape[1:] == self.output_shape, (str(self), Y.shape[1:], self.output_shape)
def forward(self, lut):
if not self.unsafe:
assert len(self.parents) > 0, self
B = []
for parent in self.parents:
# TODO: skip over irrelevant nodes (if any)
X = lut[parent]
if not self.unsafe:
self.validate_input(X)
B.append(X)
Y = self.multi(B)
if not self.unsafe:
self.validate_output(Y)
return Y
def backward(self, lut):
if not self.unsafe:
assert len(self.children) > 0, self
dB = []
for child in self.children:
# TODO: skip over irrelevant nodes (if any)
dY = lut[child]
if not self.unsafe:
self.validate_output(dY)
dB.append(dY)
dX = self.dmulti(dB)
if not self.unsafe:
self.validate_input(dX)
return dX
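# wiring sketch (illustrative): layers form a graph through feed(), which checks shape
# compatibility and records parent/child links; Model (further below) then orders it:
#   x = Input(shape=(3,))
#   y = x.feed(Relu())   # Relu inherits the (3,) shape from its parent
#   model = Model(x, y)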
# Final Layers
class Sum(Layer):
def multi(self, B):
return np.sum(B, axis=0)
def dmulti(self, dB):
#assert len(dB) == 1, "unimplemented"
return dB[0] # TODO: does this always work?
class Input(Layer):
def __init__(self, shape):
assert shape is not None
super().__init__()
self.shape = tuple(shape)
self.input_shape = self.shape
self.output_shape = self.shape
def F(self, X):
return X
def dF(self, dY):
#self.dY = dY
return np.zeros_like(dY)
class Affine(Layer):
def __init__(self, a=1, b=0):
super().__init__()
self.a = nf(a)
self.b = nf(b)
def F(self, X):
return self.a * X + self.b
def dF(self, dY):
return dY * self.a
class Sigmoid(Layer): # aka Logistic
    def F(self, X):
        self.sig = sigmoid(X)
        return self.sig
    def dF(self, dY):
        return dY * self.sig * (1 - self.sig)
class Tanh(Layer):
    def F(self, X):
        self.sig = np.tanh(X)
        return self.sig
    def dF(self, dY):
        return dY * (1 - self.sig * self.sig)
class Relu(Layer):
def F(self, X):
self.cond = X >= 0
return np.where(self.cond, X, 0)
def dF(self, dY):
return np.where(self.cond, dY, 0)
class Elu(Layer):
# paper: https://arxiv.org/abs/1511.07289
def __init__(self, alpha=1):
super().__init__()
self.alpha = nf(alpha)
    def F(self, X):
        self.cond = X >= 0
        self.neg = self.alpha * (np.exp(X) - 1)
        return np.where(self.cond, X, self.neg)
    def dF(self, dY):
        return dY * np.where(self.cond, 1, self.neg + self.alpha)
class GeluApprox(Layer):
# paper: https://arxiv.org/abs/1606.08415
# plot: https://www.desmos.com/calculator/ydzgtccsld
def F(self, X):
self.a = 1.704 * X
self.sig = sigmoid(self.a)
return X * self.sig
def dF(self, dY):
return dY * self.sig * (1 + self.a * (1 - self.sig))
class Dense(Layer):
def __init__(self, dim, init=init_he_uniform):
super().__init__()
self.dim = ni(dim)
self.output_shape = (dim,)
self.weight_init = init
self.size = None
def make_shape(self, shape):
super().make_shape(shape)
if len(shape) != 1:
return False
self.nW = self.dim * shape[0]
self.nb = self.dim
self.size = self.nW + self.nb
return shape
def init(self, W, dW):
ins, outs = self.input_shape[0], self.output_shape[0]
self.W = W
self.dW = dW
self.coeffs = self.W[:self.nW].reshape(ins, outs)
self.biases = self.W[self.nW:].reshape(1, outs)
self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
self.dbiases = self.dW[self.nW:].reshape(1, outs)
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
self.biases.flat = 0
def F(self, X):
self.X = X
Y = X.dot(self.coeffs) \
+ self.biases
return Y
def dF(self, dY):
dX = dY.dot(self.coeffs.T)
self.dcoeffs[:] = self.X.T.dot(dY)
self.dbiases[:] = dY.sum(0, keepdims=True)
return dX
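# layout note (for reference): Dense never owns its parameters. init() receives a slice
# of the model's flat W/dW vectors and builds reshaped views into it; e.g. with
# input_shape=(4,) and dim=3, W[:12] is viewed as the (4, 3) coeffs matrix and W[12:15]
# as the (1, 3) bias row, so in-place optimizer updates to W are seen by the layer.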
class DenseOneLess(Dense):
def init(self, W, dW):
super().init(W, dW)
ins, outs = self.input_shape[0], self.output_shape[0]
assert ins == outs, (ins, outs)
def F(self, X):
np.fill_diagonal(self.coeffs, 0)
self.X = X
Y = X.dot(self.coeffs) \
+ self.biases
return Y
def dF(self, dY):
dX = dY.dot(self.coeffs.T)
self.dcoeffs[:] = self.X.T.dot(dY)
self.dbiases[:] = dY.sum(0, keepdims=True)
np.fill_diagonal(self.dcoeffs, 0)
return dX
class LayerNorm(Layer): # TODO: inherit Affine instead?
def __init__(self, eps=1e-3, axis=-1):
super().__init__()
self.eps = nf(eps)
self.axis = int(axis)
def F(self, X):
self.center = X - np.mean(X, axis=self.axis, keepdims=True)
#self.var = np.var(X, axis=self.axis, keepdims=True) + self.eps
self.var = np.mean(np.square(self.center), axis=self.axis, keepdims=True) + self.eps
self.std = np.sqrt(self.var) + self.eps
Y = self.center / self.std
return Y
def dF(self, dY):
length = self.input_shape[self.axis]
dstd = dY * (-self.center / self.var)
dvar = dstd * (0.5 / self.std)
dcenter2 = dvar * (1 / length)
dcenter = dY * (1 / self.std)
dcenter += dcenter2 * (2 * self.center)
dX = dcenter - dcenter / length
return dX
# Model
class Model:
def __init__(self, x, y, unsafe=False):
assert isinstance(x, Layer), x
assert isinstance(y, Layer), y
self.x = x
self.y = y
self.ordered_nodes = self.traverse([], self.y)
self.make_weights()
for node in self.ordered_nodes:
node.unsafe = unsafe
def make_weights(self):
self.param_count = 0
for node in self.ordered_nodes:
if node.size is not None:
self.param_count += node.size
self.W = np.zeros(self.param_count, dtype=nf)
self.dW = np.zeros(self.param_count, dtype=nf)
offset = 0
for node in self.ordered_nodes:
if node.size is not None:
end = offset + node.size
node.init(self.W[offset:end], self.dW[offset:end])
offset += node.size
def traverse(self, nodes, node):
if node == self.x:
return [node]
for parent in node.parents:
if parent not in nodes:
new_nodes = self.traverse(nodes, parent)
for new_node in new_nodes:
if new_node not in nodes:
nodes.append(new_node)
if nodes:
nodes.append(node)
return nodes
def forward(self, X):
lut = dict()
input_node = self.ordered_nodes[0]
output_node = self.ordered_nodes[-1]
lut[input_node] = input_node.multi(np.expand_dims(X, 0))
for node in self.ordered_nodes[1:]:
lut[node] = node.forward(lut)
return lut[output_node]
def backward(self, error):
lut = dict()
input_node = self.ordered_nodes[0]
output_node = self.ordered_nodes[-1]
lut[output_node] = output_node.dmulti(np.expand_dims(error, 0))
for node in reversed(self.ordered_nodes[:-1]):
lut[node] = node.backward(lut)
#return lut[input_node] # meaningless value
return self.dW
def load_weights(self, fn):
# seemingly compatible with keras' Dense layers.
# ignores any non-Dense layer types.
        import h5py
        f = h5py.File(fn, 'r')  # read-only; also fails loudly if the file is missing
weights = {}
def visitor(name, obj):
if isinstance(obj, h5py.Dataset):
weights[name.split('/')[-1]] = nfa(obj[:])
f.visititems(visitor)
f.close()
denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
for i in range(len(denses)):
a, b = i, i + 1
b_name = "dense_{}".format(b)
# TODO: write a Dense method instead of assigning directly
denses[a].coeffs[:] = weights[b_name+'_W']
denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
    def save_weights(self, fn, overwrite=False):
        # note: overwrite is currently ignored; mode 'w' always truncates.
        import h5py
        f = h5py.File(fn, 'w')
denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
for i in range(len(denses)):
a, b = i, i + 1
b_name = "dense_{}".format(b)
# TODO: write a Dense method instead of assigning directly
grp = f.create_group(b_name)
data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=nf)
data[:] = denses[a].coeffs
data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=nf)
data[:] = denses[a].biases
f.close()
class Ritual: # i'm just making up names at this point
def __init__(self, learner=None, loss=None, mloss=None):
self.learner = learner if learner is not None else Learner(Optimizer())
self.loss = loss if loss is not None else Squared()
self.mloss = mloss if mloss is not None else loss
def reset(self):
self.learner.reset(optim=True)
def measure(self, residual):
return self.mloss.mean(residual)
def derive(self, residual):
return self.loss.dmean(residual)
def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
cumsum_loss = 0
batch_count = inputs.shape[0] // batch_size
losses = []
for b in range(batch_count):
bi = b * batch_size
batch_inputs = inputs[ bi:bi+batch_size]
batch_outputs = outputs[bi:bi+batch_size]
if self.learner.per_batch:
self.learner.batch(b / batch_count)
predicted = model.forward(batch_inputs)
residual = predicted - batch_outputs
model.backward(self.derive(residual))
self.learner.optim.update(model.dW, model.W)
batch_loss = self.measure(residual)
if np.isnan(batch_loss):
raise Exception("nan")
cumsum_loss += batch_loss
if return_losses:
losses.append(batch_loss)
avg_loss = cumsum_loss / batch_count
if return_losses:
return avg_loss, losses
else:
return avg_loss
class Learner:
per_batch = False
def __init__(self, optim, epochs=100, rate=None):
assert isinstance(optim, Optimizer)
self.optim = optim
self.start_rate = optim.alpha if rate is None else float(rate)
self.epochs = int(epochs)
self.reset()
def reset(self, optim=False):
self.started = False
self.epoch = 0
if optim:
self.optim.reset()
@property
def epoch(self):
return self._epoch
@epoch.setter
def epoch(self, new_epoch):
self._epoch = int(new_epoch)
self.rate = self.rate_at(self._epoch)
@property
def rate(self):
return self.optim.alpha
@rate.setter
def rate(self, new_rate):
self.optim.alpha = new_rate
def rate_at(self, epoch):
return self.start_rate
def next(self):
# prepares the next epoch. returns whether or not to continue training.
if self.epoch + 1 >= self.epochs:
return False
if self.started:
self.epoch += 1
else:
self.started = True
self.epoch = self.epoch # poke property setter just in case
return True
def batch(self, progress): # TODO: rename
# interpolates rates between epochs.
# unlike epochs, we do not store batch number as a state.
# i.e. calling next() will not respect progress.
assert 0 <= progress <= 1
self.rate = self.rate_at(self._epoch + progress)
@property
def final_rate(self):
return self.rate_at(self.epochs - 1)
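# usage sketch (illustrative): a Learner wraps an Optimizer, owns the epoch count,
# and sets the optimizer's learning rate each epoch:
#   learner = Learner(Adam(), epochs=10)
#   while learner.next():
#       pass  # train one epoch here; learner.rate is the rate for this epoch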
class AnnealingLearner(Learner):
def __init__(self, optim, epochs=100, rate=None, halve_every=10):
self.halve_every = float(halve_every)
self.anneal = 0.5**(1/self.halve_every)
super().__init__(optim, epochs, rate)
def rate_at(self, epoch):
return self.start_rate * self.anneal**epoch
class DumbLearner(AnnealingLearner):
# this is my own awful contraption. it's not really "SGD with restarts".
def __init__(self, optim, epochs=100, rate=None, halve_every=10, restarts=0, restart_advance=20, callback=None):
self.restart_epochs = int(epochs)
self.restarts = int(restarts)
self.restart_advance = float(restart_advance)
self.restart_callback = callback
epochs = self.restart_epochs * (self.restarts + 1)
super().__init__(optim, epochs, rate, halve_every)
def rate_at(self, epoch):
sub_epoch = epoch % self.restart_epochs
restart = epoch // self.restart_epochs
return super().rate_at(sub_epoch) * (self.anneal**self.restart_advance)**restart
def next(self):
if not super().next():
return False
sub_epoch = self.epoch % self.restart_epochs
restart = self.epoch // self.restart_epochs
if restart > 0 and sub_epoch == 0:
if self.restart_callback is not None:
self.restart_callback(restart)
return True
def cosmod(x):
# plot: https://www.desmos.com/calculator/hlgqmyswy2
return (1 + np.cos((x % 1) * np.pi)) / 2
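# schedule note (for reference): within each restart period, SGDR below scales the rate
# by cosmod(progress), half a cosine from 1 down to 0; e.g. cosmod at 0, 0.25, 0.5, 0.75
# gives 1.0, ~0.85, 0.5, ~0.15, and each restart multiplies the starting rate by
# restart_decay.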
class SGDR(Learner):
# Stochastic Gradient Descent with Restarts
# paper: https://arxiv.org/abs/1608.03983
# NOTE: this is not a complete implementation.
per_batch = True
def __init__(self, optim, epochs=100, rate=None, restarts=0, restart_decay=0.5, callback=None):
self.restart_epochs = int(epochs)
self.decay = float(restart_decay)
self.restarts = int(restarts)
self.restart_callback = callback
epochs = self.restart_epochs * (self.restarts + 1)
super().__init__(optim, epochs, rate)
def rate_at(self, epoch):
sub_epoch = epoch % self.restart_epochs
x = sub_epoch / self.restart_epochs
restart = epoch // self.restart_epochs
return self.start_rate * self.decay**restart * cosmod(x)
def next(self):
if not super().next():
return False
sub_epoch = self.epoch % self.restart_epochs
restart = self.epoch // self.restart_epochs
if restart > 0 and sub_epoch == 0:
if self.restart_callback is not None:
self.restart_callback(restart)
return True
def multiresnet(x, width, depth, block=2, multi=1,
activation=Relu, style='batchless',
init=init_he_normal):
y = x
last_size = x.output_shape[0]
FC = lambda size: Dense(size, init)
#FC = lambda size: DenseOneLess(size, init)
for d in range(depth):
size = width
if last_size != size:
y = y.feed(Dense(size, init))
if style == 'batchless':
skip = y
merger = Sum()
skip.feed(merger)
z_start = skip.feed(activation())
for i in range(multi):
z = z_start
for i in range(block):
if i > 0:
z = z.feed(activation())
z = z.feed(FC(size))
z.feed(merger)
y = merger
elif style == 'onelesssum':
is_last = d + 1 == depth
needs_sum = not is_last or multi > 1
skip = y
if needs_sum:
merger = Sum()
if not is_last:
skip.feed(merger)
z_start = skip.feed(activation())
for i in range(multi):
z = z_start
for i in range(block):
if i > 0:
z = z.feed(activation())
z = z.feed(FC(size))
if needs_sum:
z.feed(merger)
if needs_sum:
y = merger
else:
y = z
else:
raise Exception('unknown resnet style', style)
last_size = size
return y
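# construction sketch (illustrative parameters): multiresnet only wires layers together;
# the caller still appends an output Dense and builds the Model, e.g.
#   x = Input(shape=(6,))
#   y = multiresnet(x, width=28, depth=2)
#   y = y.feed(Dense(3))
#   model = Model(x, y)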
inits = dict(he_normal=init_he_normal, he_uniform=init_he_uniform)
activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
def run(program, args=[]):
import sys
lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
def log(left, right):
lament("{:>20}: {}".format(left, right))
# Config
from dotmap import DotMap
config = DotMap(
fn_load = None,
fn_save = 'optim_nn.h5',
log_fn = 'losses.npz',
2017-01-09 03:37:35 -08:00
# multi-residual network parameters
res_width = 28,
res_depth = 2,
res_block = 3, # normally 2 for plain resnet
res_multi = 2, # normally 1 for plain resnet
# style of resnet (order of layers, which layers, etc.)
parallel_style = 'onelesssum',
activation = 'gelu',
optim = 'adam',
nesterov = False, # only used with SGD or Adam
momentum = 0.33, # only used with SGD
# learning parameters
learner = 'SGDR',
learn = 1e-2,
epochs = 24,
restarts = 2,
learn_decay = 0.25, # only used with SGDR
learn_halve_every = 16, # unused with SGDR
learn_restart_advance = 16, # unused with SGDR
# misc
batch_size = 64,
init = 'he_normal',
loss = SomethingElse(),
mloss = 'mse',
restart_optim = False, # restarts also reset internal state of optimizer
unsafe = True, # aka gotta go fast mode
train_compare = None,
valid_compare = 0.0000946,
)
config.pprint()
# toy CIE-2000 data
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, \
inputs, outputs, valid_inputs, valid_outputs, \
x_scale, y_scale
# Our Test Model
init = inits[config.init]
activation = activations[config.activation]
x = Input(shape=(input_samples,))
y = x
y = multiresnet(y,
config.res_width, config.res_depth,
config.res_block, config.res_multi,
activation=activation, init=init,
style=config.parallel_style)
if y.output_shape[0] != output_samples:
y = y.feed(Dense(output_samples, init))
model = Model(x, y, unsafe=config.unsafe)
2017-01-09 03:37:35 -08:00
if 0:
node_names = ' '.join([str(node) for node in model.ordered_nodes])
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
else:
for node in model.ordered_nodes:
children = [str(n) for n in node.children]
if len(children) > 0:
sep = '->'
print(str(node)+sep+('\n'+str(node)+sep).join(children))
log('parameters', model.param_count)
#
training = config.epochs > 0 and config.restarts >= 0
if config.fn_load is not None:
log('loading weights', config.fn_load)
model.load_weights(config.fn_load)
#
if config.optim == 'adam':
assert not config.nesterov, "unimplemented"
optim = Adam()
elif config.optim == 'sgd':
if config.momentum != 0:
optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
else:
optim = Optimizer()
else:
raise Exception('unknown optimizer', config.optim)
def rscb(restart):
measure_error() # declared later...
log("restarting", restart)
if config.restart_optim:
optim.reset()
#
if config.learner == 'SGDR':
#decay = 0.5**(1/(config.epochs / config.learn_halve_every))
learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
restart_decay=config.learn_decay, restarts=config.restarts,
callback=rscb)
# final learning rate isn't of interest here; it's gonna be close to 0.
else:
learner = DumbLearner(optim, epochs=config.epochs, rate=config.learn,
halve_every=config.learn_halve_every,
restarts=config.restarts, restart_advance=config.learn_restart_advance,
callback=rscb)
log("final learning rate", "{:10.8f}".format(learner.final_rate))
#
def lookup_loss(maybe_name):
if isinstance(maybe_name, Loss):
return maybe_name
elif maybe_name == 'mse':
return Squared()
elif maybe_name == 'mshe': # mushy
return SquaredHalved()
raise Exception('unknown objective', maybe_name)
loss = lookup_loss(config.loss)
mloss = lookup_loss(config.mloss) if config.mloss else loss
ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
# Training
batch_losses = []
train_losses = []
valid_losses = []
def measure_error():
def print_error(name, inputs, outputs, comparison=None):
predicted = model.forward(inputs)
residual = predicted - outputs
err = ritual.measure(residual)
log(name + " loss", "{:11.7f}".format(err))
if comparison:
log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
return err
train_err = print_error("train",
inputs / x_scale, outputs / y_scale,
config.train_compare)
valid_err = print_error("valid",
valid_inputs / x_scale, valid_outputs / y_scale,
config.valid_compare)
train_losses.append(train_err)
valid_losses.append(valid_err)
measure_error()
    assert inputs.shape[0] % config.batch_size == 0, \
        "number of inputs is not evenly divisible by batch_size" # TODO: lift this restriction
while learner.next():
indices = np.arange(inputs.shape[0])
np.random.shuffle(indices)
shuffled_inputs = inputs[indices] / x_scale
shuffled_outputs = outputs[indices] / y_scale
avg_loss, losses = ritual.train_batched(model,
shuffled_inputs, shuffled_outputs,
config.batch_size,
return_losses=True)
batch_losses += losses
#log("learning rate", "{:10.8f}".format(learner.rate))
#log("average loss", "{:11.7f}".format(avg_loss))
fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:11.7f}"
log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss))
measure_error()
2017-01-09 03:37:35 -08:00
if config.fn_save is not None:
log('saving weights', config.fn_save)
model.save_weights(config.fn_save, overwrite=True)
# Evaluation
# this is just an example/test of how to predict a single output;
# it doesn't measure the quality of the network or anything.
a = (192, 128, 64)
b = (64, 128, 192)
X = np.expand_dims(np.hstack((a, b)), 0) / x_scale
P = model.forward(X) * y_scale
log("truth", rgbcompare(a, b))
log("network", np.squeeze(P))
if config.log_fn is not None:
np.savez_compressed(config.log_fn,
batch_losses=nfa(batch_losses),
train_losses=nfa(train_losses),
valid_losses=nfa(valid_losses))
return 0
if __name__ == '__main__':
import sys
sys.exit(run(sys.argv[0], sys.argv[1:]))