parent 5e06190acc
commit 315213be6d
1 changed file with 39 additions and 14 deletions

optim_nn.py (53 lines changed: +39 −14)
@@ -123,7 +123,7 @@ class Layer:
         _layer_counters[kind] += 1
         self.name = "{}_{}".format(kind, _layer_counters[kind])
         self.size = None # total weight count (if any)
-        self.unsafe = False # aka gotta go fast mode
+        self.unsafe = False # disables assertions for better performance
 
     def __str__(self):
         return self.name
@@ -288,6 +288,20 @@ class Relu(Layer):
     def dF(self, dY):
         return np.where(self.cond, dY, 0)
 
+class Elu(Layer):
+    # paper: https://arxiv.org/abs/1511.07289
+    def __init__(self, alpha=1):
+        super().__init__()
+        self.alpha = nf(alpha)
+
+    def F(self, X):
+        self.cond = X >= 0
+        self.neg = self.alpha * (np.exp(X) - 1)
+        return np.where(self.cond, X, self.neg)
+
+    def dF(self, dY):
+        return dY * np.where(self.cond, 1, self.neg + self.alpha)
+
 class GeluApprox(Layer):
     # paper: https://arxiv.org/abs/1606.08415
     # plot: https://www.desmos.com/calculator/ydzgtccsld
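Note: Elu implements the exponential linear unit from the cited paper,
F(x) = x for x >= 0 and alpha*(exp(x) - 1) otherwise; for x < 0 the
derivative alpha*exp(x) equals F(x) + alpha, which is why dF can reuse
self.neg instead of recomputing the exponential. A minimal standalone
check of that identity against finite differences (plain numpy, not
depending on the Layer class here):

    import numpy as np

    def elu(x, alpha=1.0):
        # exponential linear unit: x for x >= 0, alpha*(exp(x) - 1) otherwise
        return np.where(x >= 0, x, alpha * (np.exp(x) - 1))

    def elu_grad(x, alpha=1.0):
        # for x < 0, d/dx alpha*(exp(x) - 1) = alpha*exp(x) = elu(x) + alpha
        return np.where(x >= 0, 1.0, elu(x, alpha) + alpha)

    x = np.linspace(-4, 4, 9)
    eps = 1e-6
    numeric = (elu(x + eps) - elu(x - eps)) / (2 * eps)
    assert np.allclose(elu_grad(x), numeric, atol=1e-5)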
@@ -437,15 +451,15 @@ if __name__ == '__main__':
         fn = 'ml/cie_mlp_min.h5',
 
         # multi-residual network parameters
-        res_width = 12,
-        res_depth = 3,
-        res_block = 2, # normally 2 for plain resnet
-        res_multi = 4, # normally 1 for plain resnet
+        res_width = 49,
+        res_depth = 1,
+        res_block = 4, # normally 2 for plain resnet
+        res_multi = 1, # normally 1 for plain resnet
 
-        # style of resnet
+        # style of resnet (order of layers, which layers, etc.)
         # only one is implemented so far
         parallel_style = 'batchless',
-        activation = 'relu',
+        activation = 'gelu',
 
         optim = 'adam',
         nesterov = False, # only used with SGD or Adam
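Note: the reshuffled residual parameters trade depth for width. Assuming
res_depth counts residual stages, res_multi parallel branches per stage,
and res_block dense layers per branch (the loop over config.res_depth
appears further down), the total count of residual layers drops sharply
while each layer gets wider; a back-of-the-envelope sketch under that
reading:

    def res_layer_count(depth, multi, block):
        # hypothetical bookkeeping under the stage/branch/block reading above
        return depth * multi * block

    print(res_layer_count(3, 4, 2))  # old config: 24 layers of width 12
    print(res_layer_count(1, 1, 4))  # new config:  4 layers of width 49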
@@ -453,17 +467,21 @@ if __name__ == '__main__':
 
         # learning parameters: SGD with restarts (kinda)
         LR = 1e-2,
-        epochs = 6,
-        LR_halve_every = 2,
-        restarts = 3,
-        LR_restart_advance = 3,
+        epochs = 24,
+        LR_halve_every = 16,
+        restarts = 2,
+        LR_restart_advance = 16,
 
         # misc
         batch_size = 64,
         init = 'he_normal',
         loss = 'mse',
+        restart_optim = False, # if True, restarts also reset the optimizer's internal state
+        unsafe = False, # disables assertions for better performance
     )
 
+    config.pprint()
+
     # toy CIE-2000 data
     from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, x_scale, y_scale
 
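Note: the new schedule runs four times as many epochs per cycle but decays
far more gently. Assuming LRprod = 0.5**(1/LR_halve_every) (consistent with
the LR_halve_every name and the LRprod**config.LR_restart_advance line in
the restart hunk below), the effective learning rate can be tabulated up
front; LRprod and the per-epoch decay here are assumptions, not code from
this file:

    LR = 1e-2
    epochs = 24
    LR_halve_every = 16
    restarts = 2
    LR_restart_advance = 16

    LRprod = 0.5 ** (1 / LR_halve_every)  # halve LR every LR_halve_every epochs
    for i in range(restarts + 1):
        if i > 0:
            # each restart advances the schedule, as in the restart code below
            LR *= LRprod ** LR_restart_advance
        for epoch in range(epochs):
            lr = LR * LRprod ** epoch
            print('restart {} epoch {:2d} lr {:.6f}'.format(i, epoch + 1, lr))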
@@ -485,7 +503,7 @@ if __name__ == '__main__':
     y = x
     last_size = input_samples
 
-    activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, gelu=GeluApprox)
+    activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
     activation = activations[config.activation]
 
     for blah in range(config.res_depth):
@@ -513,7 +531,7 @@ if __name__ == '__main__':
     if last_size != output_samples:
         y = y.feed(Dense(output_samples))
 
-    model = Model(x, y, unsafe=False)
+    model = Model(x, y, unsafe=config.unsafe)
 
     node_names = ' '.join([str(node) for node in model.ordered_nodes])
     log('{} nodes'.format(len(model.ordered_nodes)), node_names)
@@ -522,6 +540,12 @@ if __name__ == '__main__':
     training = config.epochs > 0 and config.restarts >= 0
 
     if not training:
+        assert config.res_width == 12
+        assert config.res_depth == 3
+        assert config.res_block == 2
+        assert config.res_multi == 4
+        assert config.activation == 'relu'
+        assert config.parallel_style == 'batchless'
         model.load_weights(config.fn)
 
     if config.optim == 'adam':
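Note: these assertions pin the config to the architecture that the
pretrained weights in ml/cie_mlp_min.h5 were trained with, so a mismatched
load fails loudly instead of producing shape errors deeper in the model.
A hypothetical refactor gathering the same checks into one table (the
expected dict is introduced here for illustration; it is not in the source):

    expected = dict(res_width=12, res_depth=3, res_block=2, res_multi=4,
                    activation='relu', parallel_style='batchless')
    for key, value in expected.items():
        assert getattr(config, key) == value, \
            "{} must be {} to load {}".format(key, value, config.fn)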
@@ -569,7 +593,8 @@ if __name__ == '__main__':
         if i > 0:
             log("restarting", i)
             LR *= LRprod**config.LR_restart_advance
-            optim.reset()
+            if config.restart_optim:
+                optim.reset()
 
         assert inputs.shape[0] % config.batch_size == 0, \
             "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
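Note: the TODO on the divisibility assertion usually comes down to handling
a short final batch. A hedged sketch of one way to lift it (the helper name
is illustrative, not from the source; the loss averaging in the training
loop would also need to use the actual batch length):

    def iterate_batches(inputs, outputs, batch_size):
        # yield (input, output) slices, allowing a short final batch
        n = inputs.shape[0]
        for start in range(0, n, batch_size):
            stop = min(start + batch_size, n)
            yield inputs[start:stop], outputs[start:stop]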