This commit is contained in:
parent 315213be6d
commit 02e03ad85e
1 changed file with 145 additions and 58 deletions
optim_nn.py
@@ -6,8 +6,22 @@ nfa = lambda x: np.array(x, dtype=nf)
 ni = np.int
 nia = lambda x: np.array(x, dtype=ni)
 
+from scipy.special import expit as sigmoid
+
 from collections import defaultdict
 
+# Initializations
+
+# note: these are currently only implemented for 2D shapes.
+
+def init_he_normal(size, ins, outs):
+    s = np.sqrt(2 / ins)
+    return np.random.normal(0, s, size=size)
+
+def init_he_uniform(size, ins, outs):
+    s = np.sqrt(6 / ins)
+    return np.random.uniform(-s, s, size=size)
+
 # Loss functions
 
 class Loss:
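Note: both initializers scale with fan-in, and they match in variance: init_he_normal draws from N(0, 2/ins), while init_he_uniform draws from U(-s, s) with s = sqrt(6/ins), and Var(U(-s, s)) = s²/3 = 2/ins. A minimal sketch of that equivalence, assuming only numpy (the sizes here are made up for illustration):

    import numpy as np

    ins, outs = 256, 128
    normal  = np.random.normal(0, np.sqrt(2 / ins), size=(ins, outs))
    s       = np.sqrt(6 / ins)
    uniform = np.random.uniform(-s, s, size=(ins, outs))
    # both sample variances should land close to 2/ins ≈ 0.0078
    print(normal.var(), uniform.var(), 2 / ins)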
@@ -32,6 +46,21 @@ class SquaredHalved(Loss):
     def df(self, r):
         return r
 
+class SomethingElse(Loss):
+    # generalizes Absolute and SquaredHalved
+    # plot: https://www.desmos.com/calculator/fagjg9vuz7
+    def __init__(self, a=4/3):
+        assert 1 <= a <= 2, "parameter out of range"
+        self.a = nf(a / 2)
+        self.b = nf(2 / a)
+        self.c = nf(2 / a - 1)
+
+    def f(self, r):
+        return self.a * np.abs(r)**self.b
+
+    def df(self, r):
+        return np.sign(r) * np.abs(r)**self.c
+
 # Optimizers
 
 class Optimizer:
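SomethingElse computes f(r) = (a/2)·|r|^(2/a), so a=1 reproduces SquaredHalved (r²/2) and a=2 reproduces Absolute (|r|); df is the exact derivative, sign(r)·|r|^(2/a−1). A quick numerical check of those endpoints, a sketch assuming the classes above are in scope (e.g. imported from optim_nn):

    import numpy as np

    r = np.linspace(-2, 2, 9)
    assert np.allclose(SomethingElse(a=1).f(r), r * r / 2)   # SquaredHalved
    assert np.allclose(SomethingElse(a=2).f(r), np.abs(r))   # Absolute
    assert np.allclose(SomethingElse(a=1).df(r), r)          # matches SquaredHalved.df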
@@ -102,8 +131,8 @@ class Adam(Optimizer):
         self.b1_t *= self.b1
         self.b2_t *= self.b2
 
-        self.mt = self.b1 * self.mt + (1 - self.b1) * dW
-        self.vt = self.b2 * self.vt + (1 - self.b2) * dW * dW
+        self.mt[:] = self.b1 * self.mt + (1 - self.b1) * dW
+        self.vt[:] = self.b2 * self.vt + (1 - self.b2) * dW * dW
 
         return -self.alpha * (self.mt / (1 - self.b1_t)) \
                / np.sqrt((self.vt / (1 - self.b2_t)) + self.eps)
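The `[:]` change makes the moment updates write into the existing mt/vt buffers instead of rebinding the attributes to freshly allocated arrays, so anything holding a reference to those arrays keeps seeing current values. A minimal numpy sketch of the distinction (variable names here are illustrative, not from the source):

    import numpy as np

    m = np.zeros(3)
    view = m
    m = m + 1       # rebinds m; view still sees the old buffer
    print(view)     # [0. 0. 0.]

    m = view
    m[:] = m + 1    # writes through the same buffer
    print(view)     # [1. 1. 1.]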
@@ -306,7 +335,6 @@ class GeluApprox(Layer):
     # paper: https://arxiv.org/abs/1606.08415
     # plot: https://www.desmos.com/calculator/ydzgtccsld
     def F(self, X):
-        from scipy.special import expit as sigmoid
         self.a = 1.704 * X
         self.sig = sigmoid(self.a)
         return X * self.sig
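For context, F computes x·sigmoid(1.704·x), a sigmoid-based approximation of exact GELU, x·Φ(x). A sketch comparing the two, assuming scipy (the printed error magnitude is an empirical observation, not a guarantee):

    import numpy as np
    from scipy.special import expit as sigmoid, erf

    x = np.linspace(-4, 4, 81)
    approx = x * sigmoid(1.704 * x)
    exact = 0.5 * x * (1 + erf(x / np.sqrt(2)))  # x * standard normal CDF
    print(np.abs(approx - exact).max())          # roughly on the order of 1e-2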
@@ -331,9 +359,7 @@ class Dense(Layer):
         self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
         self.dbiases = self.dW[self.nW:].reshape(1, outs)
 
-        # he_normal initialization
-        s = np.sqrt(2 / ins)
-        self.coeffs.flat = np.random.normal(0, s, size=self.nW)
+        self.coeffs.flat = init_he_uniform(self.nW, ins, outs)
         self.biases.flat = 0
 
     def make_shape(self, shape):
@@ -432,12 +458,64 @@ class Model:
         for i in range(len(denses)):
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
+            # TODO: write a Dense method instead of assigning directly
             denses[a].coeffs = weights[b_name+'_W']
             denses[a].biases = np.expand_dims(weights[b_name+'_b'], 0)
 
     def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)
 
+def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
+    y = x
+    last_size = x.output_shape[0]
+
+    for d in range(depth):
+        size = width
+
+        if last_size != size:
+            y = y.feed(Dense(size))
+
+        if style == 'batchless':
+            skip = y
+            merger = Sum()
+            skip.feed(merger)
+            z_start = skip.feed(activation())
+            for i in range(multi):
+                z = z_start
+                for i in range(block):
+                    if i > 0:
+                        z = z.feed(activation())
+                    z = z.feed(Dense(size))
+                z.feed(merger)
+            y = merger
+        elif style == 'onelesssum':
+            is_last = d + 1 == depth
+            needs_sum = not is_last or multi > 1
+            skip = y
+            if needs_sum:
+                merger = Sum()
+            if not is_last:
+                skip.feed(merger)
+            z_start = skip.feed(activation())
+            for i in range(multi):
+                z = z_start
+                for i in range(block):
+                    if i > 0:
+                        z = z.feed(activation())
+                    z = z.feed(Dense(size))
+                if needs_sum:
+                    z.feed(merger)
+            if needs_sum:
+                y = merger
+            else:
+                y = z
+        else:
+            raise Exception('unknown resnet style', style)
+
+        last_size = size
+
+    return y
+
 if __name__ == '__main__':
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
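A sketch of how the new multiresnet helper gets wired into a model, mirroring the __main__ code further down (the layer sizes here are made up for illustration):

    x = Input(shape=(1,))
    y = multiresnet(x, width=12, depth=2, block=2, multi=1,
                    activation=GeluApprox, style='batchless')
    if y.output_shape[0] != 3:
        y = y.feed(Dense(3))
    model = Model(x, y, unsafe=False)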
@@ -457,8 +535,7 @@ if __name__ == '__main__':
        res_multi = 1, # normally 1 for plain resnet
 
        # style of resnet (order of layers, which layers, etc.)
-       # only one is implemented so far
-       parallel_style = 'batchless',
+       parallel_style = 'onelesssum',
        activation = 'gelu',
 
        optim = 'adam',
@@ -475,9 +552,14 @@ if __name__ == '__main__':
        # misc
        batch_size = 64,
        init = 'he_normal',
-       loss = 'mse',
-       restart_optim = False, # restarts also reset internal state of optimizer
-       unsafe = False, # aka gotta go fast mode
+       loss = SomethingElse(4/3),
+       log_fn = 'losses.npz',
+       mloss = 'mse',
+       restart_optim = True, # restarts also reset internal state of optimizer
+       unsafe = True, # aka gotta go fast mode
+       train_compare = None,
+       #valid_compare = 0.0007159,
+       valid_compare = 0.0000946,
     )
 
     config.pprint()
@@ -501,34 +583,13 @@ if __name__ == '__main__':
 
     x = Input(shape=(input_samples,))
     y = x
-    last_size = input_samples
 
     activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
     activation = activations[config.activation]
-
-    for blah in range(config.res_depth):
-        size = config.res_width
-
-        if last_size != size:
-            y = y.feed(Dense(size))
-
-        assert config.parallel_style == 'batchless'
-        skip = y
-        merger = Sum()
-        skip.feed(merger)
-        z_start = skip.feed(activation())
-        for i in range(config.res_multi):
-            z = z_start
-            for i in range(config.res_block):
-                if i > 0:
-                    z = z.feed(activation())
-                z = z.feed(Dense(size))
-            z.feed(merger)
-        y = merger
-
-        last_size = size
-
-    if last_size != output_samples:
+    y = multiresnet(y,
+                    config.res_width, config.res_depth,
+                    config.res_block, config.res_multi,
+                    activation=activation)
+    if y.output_shape[0] != output_samples:
         y = y.feed(Dense(output_samples))
 
     model = Model(x, y, unsafe=config.unsafe)
@@ -559,12 +620,17 @@ if __name__ == '__main__':
     else:
         raise Exception('unknown optimizer', config.optim)
 
-    if config.loss == 'mse':
-        loss = Squared()
-    elif config.loss == 'mshe': # mushy
-        loss = SquaredHalved()
-    else:
-        raise Exception('unknown objective', config.loss)
+    def lookup_loss(maybe_name):
+        if isinstance(maybe_name, Loss):
+            return maybe_name
+        elif maybe_name == 'mse':
+            return Squared()
+        elif maybe_name == 'mshe': # mushy
+            return SquaredHalved()
+        raise Exception('unknown objective', maybe_name)
+
+    loss = lookup_loss(config.loss)
+    mloss = lookup_loss(config.mloss) if config.mloss else loss
 
     LR = config.LR
     LRprod = 0.5**(1/config.LR_halve_every)
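lookup_loss accepts either a Loss instance or a known name, which is what lets the config above mix `loss = SomethingElse(4/3)` with `mloss = 'mse'`. For example:

    loss = lookup_loss(SomethingElse(4/3))  # a Loss instance passes through unchanged
    mloss = lookup_loss('mse')              # a name resolves to Squared()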
@@ -574,21 +640,34 @@ if __name__ == '__main__':
 
     # Training
 
-    def measure_loss():
-        predicted = model.forward(inputs / x_scale)
-        residual = predicted - outputs / y_scale
-        err = loss.mean(residual)
-        log("train loss", "{:11.7f}".format(err))
-        log("improvement", "{:+7.2f}%".format((0.0007031 / err - 1) * 100))
-
-        predicted = model.forward(valid_inputs / x_scale)
-        residual = predicted - valid_outputs / y_scale
-        err = loss.mean(residual)
-        log("valid loss", "{:11.7f}".format(err))
-        log("improvement", "{:+7.2f}%".format((0.0007159 / err - 1) * 100))
+    batch_losses = []
+    train_losses = []
+    valid_losses = []
+
+    def measure_error():
+        # log("weight mean", "{:11.7f}".format(np.mean(model.W)))
+        # log("weight var", "{:11.7f}".format(np.var(model.W)))
+
+        def print_error(name, inputs, outputs, comparison=None):
+            predicted = model.forward(inputs)
+            residual = predicted - outputs
+            err = mloss.mean(residual)
+            log(name + " loss", "{:11.7f}".format(err))
+            if comparison:
+                log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
+            return err
+
+        train_err = print_error("train",
+                                inputs / x_scale, outputs / y_scale,
+                                config.train_compare)
+        valid_err = print_error("valid",
+                                valid_inputs / x_scale, valid_outputs / y_scale,
+                                config.valid_compare)
+        train_losses.append(train_err)
+        valid_losses.append(valid_err)
 
     for i in range(config.restarts + 1):
-        measure_loss()
+        measure_error()
 
         if i > 0:
             log("restarting", i)
@@ -620,10 +699,12 @@ if __name__ == '__main__':
             optim.update(dW, model.W)
 
             # note: we don't actually need this for training, only monitoring.
-            cumsum_loss += loss.mean(residual)
-        log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
+            batch_loss = mloss.mean(residual)
+            cumsum_loss += batch_loss
+            batch_losses.append(batch_loss)
+        #log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
 
-        measure_loss()
+        measure_error()
 
         #if training:
         #    model.save_weights(config.fn, overwrite=True)
@@ -636,3 +717,9 @@ if __name__ == '__main__':
     P = model.forward(X) * y_scale
     log("truth", rgbcompare(a, b))
     log("network", np.squeeze(P))
+
+    if config.log_fn is not None:
+        np.savez_compressed(config.log_fn,
+                            batch_losses=nfa(batch_losses),
+                            train_losses=nfa(train_losses),
+                            valid_losses=nfa(valid_losses))
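The arrays written by np.savez_compressed can be read back later for plotting or comparing runs; a sketch, assuming numpy and the default log_fn from the config above:

    import numpy as np

    data = np.load('losses.npz')
    for key in ('batch_losses', 'train_losses', 'valid_losses'):
        print(key, data[key].shape)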