This commit is contained in:
parent 315213be6d
commit 02e03ad85e
1 changed file with 145 additions and 58 deletions

optim_nn.py (203 lines changed: +145, -58)
@@ -6,8 +6,22 @@ nfa = lambda x: np.array(x, dtype=nf)
 ni = np.int
 nia = lambda x: np.array(x, dtype=ni)

+from scipy.special import expit as sigmoid
+
+from collections import defaultdict
+
+# Initializations
+
+# note: these are currently only implemented for 2D shapes.
+
+def init_he_normal(size, ins, outs):
+    s = np.sqrt(2 / ins)
+    return np.random.normal(0, s, size=size)
+
+def init_he_uniform(size, ins, outs):
+    s = np.sqrt(6 / ins)
+    return np.random.uniform(-s, s, size=size)

 # Loss functions

 class Loss:
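Note on the new initializers (mine, not part of the diff): the uniform limit sqrt(6/ins) is chosen so both draws share the variance 2/ins that He initialization calls for, since a uniform on [-s, s] has variance s**2 / 3. A quick numpy-only sanity check:

import numpy as np

ins = 256
w_n = np.random.normal(0, np.sqrt(2 / ins), size=(ins, 128))
w_u = np.random.uniform(-np.sqrt(6 / ins), np.sqrt(6 / ins), size=(ins, 128))
print(w_n.var(), w_u.var(), 2 / ins)  # all three approximately 0.0078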
@@ -32,6 +46,21 @@ class SquaredHalved(Loss):
     def df(self, r):
         return r

+class SomethingElse(Loss):
+    # generalizes Absolute and SquaredHalved
+    # plot: https://www.desmos.com/calculator/fagjg9vuz7
+    def __init__(self, a=4/3):
+        assert 1 <= a <= 2, "parameter out of range"
+        self.a = nf(a / 2)
+        self.b = nf(2 / a)
+        self.c = nf(2 / a - 1)
+
+    def f(self, r):
+        return self.a * np.abs(r)**self.b
+
+    def df(self, r):
+        return np.sign(r) * np.abs(r)**self.c
+
 # Optimizers

 class Optimizer:
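Reviewer sketch (assumes optim_nn.py is importable): SomethingElse reduces to the two losses it generalizes at the ends of its allowed range, since f(r) = (a/2) * |r|**(2/a) gives |r| at a=2 and r**2 / 2 at a=1.

import numpy as np

r = np.linspace(-2.0, 2.0, 9)
absolute = SomethingElse(a=2)        # f(r) = |r|
squared_halved = SomethingElse(a=1)  # f(r) = r**2 / 2
assert np.allclose(absolute.f(r), np.abs(r))
assert np.allclose(squared_halved.f(r), r * r / 2)
assert np.allclose(squared_halved.df(r), r)  # matches SquaredHalved.df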
@@ -102,8 +131,8 @@ class Adam(Optimizer):
         self.b1_t *= self.b1
         self.b2_t *= self.b2

-        self.mt = self.b1 * self.mt + (1 - self.b1) * dW
-        self.vt = self.b2 * self.vt + (1 - self.b2) * dW * dW
+        self.mt[:] = self.b1 * self.mt + (1 - self.b1) * dW
+        self.vt[:] = self.b2 * self.vt + (1 - self.b2) * dW * dW

         return -self.alpha * (self.mt / (1 - self.b1_t)) \
             / np.sqrt((self.vt / (1 - self.b2_t)) + self.eps)
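Why the switch to slice assignment matters (my reading; the commit doesn't say): self.mt = ... rebinds the attribute to a freshly allocated array, which also silently adopts dW's dtype, while self.mt[:] = ... writes through into the existing buffer, preserving its dtype and any outside references. A standalone illustration:

import numpy as np

mt = np.zeros(4, dtype=np.float32)
dW = np.ones(4, dtype=np.float64)
alias = mt

mt = 0.9 * mt + 0.1 * dW      # rebinds: new float64 array, alias goes stale
print(mt.dtype, mt is alias)  # float64 False

mt = np.zeros(4, dtype=np.float32)
alias = mt
mt[:] = 0.9 * mt + 0.1 * dW   # in-place: dtype and identity preserved
print(mt.dtype, mt is alias)  # float32 True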
@@ -306,7 +335,6 @@ class GeluApprox(Layer):
     # paper: https://arxiv.org/abs/1606.08415
     # plot: https://www.desmos.com/calculator/ydzgtccsld
     def F(self, X):
-        from scipy.special import expit as sigmoid
         self.a = 1.704 * X
         self.sig = sigmoid(self.a)
         return X * self.sig
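For reference, x * sigmoid(1.704 * x) is the sigmoid-style approximation of GELU, whose exact form is x * Phi(x). A rough accuracy check, assuming only scipy:

import numpy as np
from scipy.special import expit as sigmoid, erf

x = np.linspace(-4, 4, 81)
exact = 0.5 * x * (1 + erf(x / np.sqrt(2)))  # x * Phi(x)
approx = x * sigmoid(1.704 * x)
print(np.abs(exact - approx).max())  # on the order of 1e-2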
@@ -331,9 +359,7 @@ class Dense(Layer):
         self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
         self.dbiases = self.dW[self.nW:].reshape(1, outs)

-        # he_normal initialization
-        s = np.sqrt(2 / ins)
-        self.coeffs.flat = np.random.normal(0, s, size=self.nW)
+        self.coeffs.flat = init_he_uniform(self.nW, ins, outs)
         self.biases.flat = 0

     def make_shape(self, shape):
@@ -432,12 +458,64 @@ class Model:
         for i in range(len(denses)):
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
+            # TODO: write a Dense method instead of assigning directly
             denses[a].coeffs = weights[b_name+'_W']
             denses[a].biases = np.expand_dims(weights[b_name+'_b'], 0)

     def save_weights(self, fn, overwrite=False):
         raise NotImplementedError("unimplemented", self)

+def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
+    y = x
+    last_size = x.output_shape[0]
+
+    for d in range(depth):
+        size = width
+
+        if last_size != size:
+            y = y.feed(Dense(size))
+
+        if style == 'batchless':
+            skip = y
+            merger = Sum()
+            skip.feed(merger)
+            z_start = skip.feed(activation())
+            for i in range(multi):
+                z = z_start
+                for i in range(block):
+                    if i > 0:
+                        z = z.feed(activation())
+                    z = z.feed(Dense(size))
+                z.feed(merger)
+            y = merger
+        elif style == 'onelesssum':
+            is_last = d + 1 == depth
+            needs_sum = not is_last or multi > 1
+            skip = y
+            if needs_sum:
+                merger = Sum()
+            if not is_last:
+                skip.feed(merger)
+            z_start = skip.feed(activation())
+            for i in range(multi):
+                z = z_start
+                for i in range(block):
+                    if i > 0:
+                        z = z.feed(activation())
+                    z = z.feed(Dense(size))
+                if needs_sum:
+                    z.feed(merger)
+            if needs_sum:
+                y = merger
+            else:
+                y = z
+        else:
+            raise Exception('unknown resnet style', style)
+
+        last_size = size
+
+    return y
+
 if __name__ == '__main__':
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
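A minimal usage sketch of the new multiresnet builder (hypothetical sizes; Input, Dense, Sum, Relu, and Model are the classes from this file):

x = Input(shape=(30,))
y = multiresnet(x, width=64, depth=2, block=2, multi=1,
                activation=Relu, style='onelesssum')
if y.output_shape[0] != 3:
    y = y.feed(Dense(3))
model = Model(x, y, unsafe=False)

With style='onelesssum', the final block skips the Sum merger entirely when nothing else would be summed into it, which is presumably where the name comes from.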
@@ -457,8 +535,7 @@ if __name__ == '__main__':
         res_multi = 1, # normally 1 for plain resnet

         # style of resnet (order of layers, which layers, etc.)
-        # only one is implemented so far
-        parallel_style = 'batchless',
+        parallel_style = 'onelesssum',
         activation = 'gelu',

         optim = 'adam',
@@ -475,9 +552,14 @@ if __name__ == '__main__':
         # misc
         batch_size = 64,
         init = 'he_normal',
-        loss = 'mse',
-        restart_optim = False, # restarts also reset internal state of optimizer
-        unsafe = False, # aka gotta go fast mode
+        loss = SomethingElse(4/3),
+        log_fn = 'losses.npz',
+        mloss = 'mse',
+        restart_optim = True, # restarts also reset internal state of optimizer
+        unsafe = True, # aka gotta go fast mode
+        train_compare = None,
+        #valid_compare = 0.0007159,
+        valid_compare = 0.0000946,
     )

     config.pprint()
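The loss/mloss split introduced here means the network trains against SomethingElse(4/3) while everything reported and compared still uses plain MSE, so the valid_compare baselines remain comparable across runs. In effect (hypothetical arrays):

residual = predicted - expected
train_err = SomethingElse(4/3).mean(residual)  # what the optimizer minimizes
report_err = Squared().mean(residual)          # what gets logged and compared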
@@ -501,34 +583,13 @@ if __name__ == '__main__':
     x = Input(shape=(input_samples,))
     y = x
-    last_size = input_samples

     activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
     activation = activations[config.activation]

-    for blah in range(config.res_depth):
-        size = config.res_width
-
-        if last_size != size:
-            y = y.feed(Dense(size))
-
-        assert config.parallel_style == 'batchless'
-        skip = y
-        merger = Sum()
-        skip.feed(merger)
-        z_start = skip.feed(activation())
-        for i in range(config.res_multi):
-            z = z_start
-            for i in range(config.res_block):
-                if i > 0:
-                    z = z.feed(activation())
-                z = z.feed(Dense(size))
-            z.feed(merger)
-        y = merger
-
-        last_size = size
-
-    if last_size != output_samples:
+    y = multiresnet(y,
+                    config.res_width, config.res_depth,
+                    config.res_block, config.res_multi,
+                    activation=activation)
+    if y.output_shape[0] != output_samples:
         y = y.feed(Dense(output_samples))

     model = Model(x, y, unsafe=config.unsafe)
@@ -559,12 +620,17 @@ if __name__ == '__main__':
     else:
         raise Exception('unknown optimizer', config.optim)

-    if config.loss == 'mse':
-        loss = Squared()
-    elif config.loss == 'mshe': # mushy
-        loss = SquaredHalved()
-    else:
-        raise Exception('unknown objective', config.loss)
+    def lookup_loss(maybe_name):
+        if isinstance(maybe_name, Loss):
+            return maybe_name
+        elif maybe_name == 'mse':
+            return Squared()
+        elif maybe_name == 'mshe': # mushy
+            return SquaredHalved()
+        raise Exception('unknown objective', maybe_name)
+
+    loss = lookup_loss(config.loss)
+    mloss = lookup_loss(config.mloss) if config.mloss else loss

     LR = config.LR
     LRprod = 0.5**(1/config.LR_halve_every)
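lookup_loss accepts either a Loss instance or a string key, which is what lets the config above mix loss = SomethingElse(4/3) with mloss = 'mse'. For instance:

assert isinstance(lookup_loss('mse'), Squared)
assert isinstance(lookup_loss('mshe'), SquaredHalved)
custom = SomethingElse(4/3)
assert lookup_loss(custom) is custom  # instances pass through unchanged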
@@ -574,21 +640,34 @@ if __name__ == '__main__':

     # Training

-    def measure_loss():
-        predicted = model.forward(inputs / x_scale)
-        residual = predicted - outputs / y_scale
-        err = loss.mean(residual)
-        log("train loss", "{:11.7f}".format(err))
-        log("improvement", "{:+7.2f}%".format((0.0007031 / err - 1) * 100))
+    batch_losses = []
+    train_losses = []
+    valid_losses = []

-        predicted = model.forward(valid_inputs / x_scale)
-        residual = predicted - valid_outputs / y_scale
-        err = loss.mean(residual)
-        log("valid loss", "{:11.7f}".format(err))
-        log("improvement", "{:+7.2f}%".format((0.0007159 / err - 1) * 100))
+    def measure_error():
+        # log("weight mean", "{:11.7f}".format(np.mean(model.W)))
+        # log("weight var", "{:11.7f}".format(np.var(model.W)))
+
+        def print_error(name, inputs, outputs, comparison=None):
+            predicted = model.forward(inputs)
+            residual = predicted - outputs
+            err = mloss.mean(residual)
+            log(name + " loss", "{:11.7f}".format(err))
+            if comparison:
+                log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
+            return err
+
+        train_err = print_error("train",
+                                inputs / x_scale, outputs / y_scale,
+                                config.train_compare)
+        valid_err = print_error("valid",
+                                valid_inputs / x_scale, valid_outputs / y_scale,
+                                config.valid_compare)
+        train_losses.append(train_err)
+        valid_losses.append(valid_err)

     for i in range(config.restarts + 1):
-        measure_loss()
+        measure_error()

         if i > 0:
             log("restarting", i)
@@ -620,10 +699,12 @@ if __name__ == '__main__':
             optim.update(dW, model.W)

             # note: we don't actually need this for training, only monitoring.
-            cumsum_loss += loss.mean(residual)
-        log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
+            batch_loss = mloss.mean(residual)
+            cumsum_loss += batch_loss
+            batch_losses.append(batch_loss)
+        #log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))

-        measure_loss()
+        measure_error()

         #if training:
         #    model.save_weights(config.fn, overwrite=True)
@@ -636,3 +717,9 @@ if __name__ == '__main__':
         P = model.forward(X) * y_scale
         log("truth", rgbcompare(a, b))
         log("network", np.squeeze(P))
+
+    if config.log_fn is not None:
+        np.savez_compressed(config.log_fn,
+                            batch_losses=nfa(batch_losses),
+                            train_losses=nfa(train_losses),
+                            valid_losses=nfa(valid_losses))
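Reading the new log file back is plain numpy; the keys match the savez_compressed call above:

import numpy as np

with np.load('losses.npz') as data:
    batch_losses = data['batch_losses']  # one value per minibatch
    train_losses = data['train_losses']  # one value per measure_error() call
    valid_losses = data['valid_losses']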