This commit is contained in:
Connor Olding 2017-01-11 02:20:49 -08:00
parent 5e06190acc
commit 315213be6d

View file

@ -123,7 +123,7 @@ class Layer:
_layer_counters[kind] += 1 _layer_counters[kind] += 1
self.name = "{}_{}".format(kind, _layer_counters[kind]) self.name = "{}_{}".format(kind, _layer_counters[kind])
self.size = None # total weight count (if any) self.size = None # total weight count (if any)
self.unsafe = False # aka gotta go fast mode self.unsafe = False # disables assertions for better performance
def __str__(self): def __str__(self):
return self.name return self.name
@ -288,6 +288,20 @@ class Relu(Layer):
def dF(self, dY): def dF(self, dY):
return np.where(self.cond, dY, 0) return np.where(self.cond, dY, 0)
class Elu(Layer):
# paper: https://arxiv.org/abs/1511.07289
def __init__(self, alpha=1):
super().__init__()
self.alpha = nf(alpha)
def F(self, X):
self.cond = X >= 0
self.neg = np.exp(X) - 1
return np.where(self.cond, X, self.neg)
def dF(self, dY):
return dY * np.where(self.cond, 1, self.neg + 1)
class GeluApprox(Layer): class GeluApprox(Layer):
# paper: https://arxiv.org/abs/1606.08415 # paper: https://arxiv.org/abs/1606.08415
# plot: https://www.desmos.com/calculator/ydzgtccsld # plot: https://www.desmos.com/calculator/ydzgtccsld
@ -437,15 +451,15 @@ if __name__ == '__main__':
fn = 'ml/cie_mlp_min.h5', fn = 'ml/cie_mlp_min.h5',
# multi-residual network parameters # multi-residual network parameters
res_width = 12, res_width = 49,
res_depth = 3, res_depth = 1,
res_block = 2, # normally 2 for plain resnet res_block = 4, # normally 2 for plain resnet
res_multi = 4, # normally 1 for plain resnet res_multi = 1, # normally 1 for plain resnet
# style of resnet # style of resnet (order of layers, which layers, etc.)
# only one is implemented so far # only one is implemented so far
parallel_style = 'batchless', parallel_style = 'batchless',
activation = 'relu', activation = 'gelu',
optim = 'adam', optim = 'adam',
nesterov = False, # only used with SGD or Adam nesterov = False, # only used with SGD or Adam
@ -453,17 +467,21 @@ if __name__ == '__main__':
# learning parameters: SGD with restarts (kinda) # learning parameters: SGD with restarts (kinda)
LR = 1e-2, LR = 1e-2,
epochs = 6, epochs = 24,
LR_halve_every = 2, LR_halve_every = 16,
restarts = 3, restarts = 2,
LR_restart_advance = 3, LR_restart_advance = 16,
# misc # misc
batch_size = 64, batch_size = 64,
init = 'he_normal', init = 'he_normal',
loss = 'mse', loss = 'mse',
restart_optim = False, # restarts also reset internal state of optimizer
unsafe = False, # aka gotta go fast mode
) )
config.pprint()
# toy CIE-2000 data # toy CIE-2000 data
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, x_scale, y_scale from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, x_scale, y_scale
@ -485,7 +503,7 @@ if __name__ == '__main__':
y = x y = x
last_size = input_samples last_size = input_samples
activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, gelu=GeluApprox) activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
activation = activations[config.activation] activation = activations[config.activation]
for blah in range(config.res_depth): for blah in range(config.res_depth):
@ -513,7 +531,7 @@ if __name__ == '__main__':
if last_size != output_samples: if last_size != output_samples:
y = y.feed(Dense(output_samples)) y = y.feed(Dense(output_samples))
model = Model(x, y, unsafe=False) model = Model(x, y, unsafe=config.unsafe)
node_names = ' '.join([str(node) for node in model.ordered_nodes]) node_names = ' '.join([str(node) for node in model.ordered_nodes])
log('{} nodes'.format(len(model.ordered_nodes)), node_names) log('{} nodes'.format(len(model.ordered_nodes)), node_names)
@ -522,6 +540,12 @@ if __name__ == '__main__':
training = config.epochs > 0 and config.restarts >= 0 training = config.epochs > 0 and config.restarts >= 0
if not training: if not training:
assert config.res_width == 12
assert config.res_depth == 3
assert config.res_block == 2
assert config.res_multi == 4
assert config.activation == 'relu'
assert config.parallel_style == 'batchless'
model.load_weights(config.fn) model.load_weights(config.fn)
if config.optim == 'adam': if config.optim == 'adam':
@ -569,7 +593,8 @@ if __name__ == '__main__':
if i > 0: if i > 0:
log("restarting", i) log("restarting", i)
LR *= LRprod**config.LR_restart_advance LR *= LRprod**config.LR_restart_advance
optim.reset() if config.restart_optim:
optim.reset()
assert inputs.shape[0] % config.batch_size == 0, \ assert inputs.shape[0] % config.batch_size == 0, \
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction "inputs is not evenly divisible by batch_size" # TODO: lift this restriction