.
This commit is contained in:
parent
5e06190acc
commit
315213be6d
1 changed files with 39 additions and 14 deletions
51
optim_nn.py
51
optim_nn.py
|
@ -123,7 +123,7 @@ class Layer:
|
|||
_layer_counters[kind] += 1
|
||||
self.name = "{}_{}".format(kind, _layer_counters[kind])
|
||||
self.size = None # total weight count (if any)
|
||||
self.unsafe = False # aka gotta go fast mode
|
||||
self.unsafe = False # disables assertions for better performance
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
@ -288,6 +288,20 @@ class Relu(Layer):
|
|||
def dF(self, dY):
|
||||
return np.where(self.cond, dY, 0)
|
||||
|
||||
class Elu(Layer):
|
||||
# paper: https://arxiv.org/abs/1511.07289
|
||||
def __init__(self, alpha=1):
|
||||
super().__init__()
|
||||
self.alpha = nf(alpha)
|
||||
|
||||
def F(self, X):
|
||||
self.cond = X >= 0
|
||||
self.neg = np.exp(X) - 1
|
||||
return np.where(self.cond, X, self.neg)
|
||||
|
||||
def dF(self, dY):
|
||||
return dY * np.where(self.cond, 1, self.neg + 1)
|
||||
|
||||
class GeluApprox(Layer):
|
||||
# paper: https://arxiv.org/abs/1606.08415
|
||||
# plot: https://www.desmos.com/calculator/ydzgtccsld
|
||||
|
@ -437,15 +451,15 @@ if __name__ == '__main__':
|
|||
fn = 'ml/cie_mlp_min.h5',
|
||||
|
||||
# multi-residual network parameters
|
||||
res_width = 12,
|
||||
res_depth = 3,
|
||||
res_block = 2, # normally 2 for plain resnet
|
||||
res_multi = 4, # normally 1 for plain resnet
|
||||
res_width = 49,
|
||||
res_depth = 1,
|
||||
res_block = 4, # normally 2 for plain resnet
|
||||
res_multi = 1, # normally 1 for plain resnet
|
||||
|
||||
# style of resnet
|
||||
# style of resnet (order of layers, which layers, etc.)
|
||||
# only one is implemented so far
|
||||
parallel_style = 'batchless',
|
||||
activation = 'relu',
|
||||
activation = 'gelu',
|
||||
|
||||
optim = 'adam',
|
||||
nesterov = False, # only used with SGD or Adam
|
||||
|
@ -453,17 +467,21 @@ if __name__ == '__main__':
|
|||
|
||||
# learning parameters: SGD with restarts (kinda)
|
||||
LR = 1e-2,
|
||||
epochs = 6,
|
||||
LR_halve_every = 2,
|
||||
restarts = 3,
|
||||
LR_restart_advance = 3,
|
||||
epochs = 24,
|
||||
LR_halve_every = 16,
|
||||
restarts = 2,
|
||||
LR_restart_advance = 16,
|
||||
|
||||
# misc
|
||||
batch_size = 64,
|
||||
init = 'he_normal',
|
||||
loss = 'mse',
|
||||
restart_optim = False, # restarts also reset internal state of optimizer
|
||||
unsafe = False, # aka gotta go fast mode
|
||||
)
|
||||
|
||||
config.pprint()
|
||||
|
||||
# toy CIE-2000 data
|
||||
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, x_scale, y_scale
|
||||
|
||||
|
@ -485,7 +503,7 @@ if __name__ == '__main__':
|
|||
y = x
|
||||
last_size = input_samples
|
||||
|
||||
activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, gelu=GeluApprox)
|
||||
activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
|
||||
activation = activations[config.activation]
|
||||
|
||||
for blah in range(config.res_depth):
|
||||
|
@ -513,7 +531,7 @@ if __name__ == '__main__':
|
|||
if last_size != output_samples:
|
||||
y = y.feed(Dense(output_samples))
|
||||
|
||||
model = Model(x, y, unsafe=False)
|
||||
model = Model(x, y, unsafe=config.unsafe)
|
||||
|
||||
node_names = ' '.join([str(node) for node in model.ordered_nodes])
|
||||
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
|
||||
|
@ -522,6 +540,12 @@ if __name__ == '__main__':
|
|||
training = config.epochs > 0 and config.restarts >= 0
|
||||
|
||||
if not training:
|
||||
assert config.res_width == 12
|
||||
assert config.res_depth == 3
|
||||
assert config.res_block == 2
|
||||
assert config.res_multi == 4
|
||||
assert config.activation == 'relu'
|
||||
assert config.parallel_style == 'batchless'
|
||||
model.load_weights(config.fn)
|
||||
|
||||
if config.optim == 'adam':
|
||||
|
@ -569,6 +593,7 @@ if __name__ == '__main__':
|
|||
if i > 0:
|
||||
log("restarting", i)
|
||||
LR *= LRprod**config.LR_restart_advance
|
||||
if config.restart_optim:
|
||||
optim.reset()
|
||||
|
||||
assert inputs.shape[0] % config.batch_size == 0, \
|
||||
|
|
Loading…
Reference in a new issue