update mnist example with new features
This commit is contained in:
parent
d08b5b91a1
commit
87ffa014ca
1 changed file with 58 additions and 14 deletions
|
@ -17,33 +17,48 @@ if use_emnist:
|
|||
|
||||
sgdr = True
|
||||
restart_decay = 0.5
|
||||
CLR = SineCLR
|
||||
|
||||
n_dense = 0
|
||||
n_denses = 2
|
||||
new_dims = (28, 28)
|
||||
activation = GeluApprox
|
||||
|
||||
reg = None
|
||||
final_reg = None
|
||||
actreg_lamb = None
|
||||
|
||||
load_fn = None
|
||||
save_fn = 'emnist.h5'
|
||||
log_fn = 'emnist_losses.npz'
|
||||
|
||||
fn = 'emnist-balanced.npz'
|
||||
mnist_dim = 28
|
||||
mnist_classes = 47
|
||||
|
||||
else:
|
||||
lr = 0.0032
|
||||
epochs = 125
|
||||
starts = 5
|
||||
epochs = 60
|
||||
starts = 3
|
||||
bs = 200
|
||||
|
||||
activation = Relu
|
||||
|
||||
sgdr = False
|
||||
sgdr = True
|
||||
restart_decay = 0.5
|
||||
CLR = SineCLR
|
||||
|
||||
n_dense = 1
|
||||
n_dense = 2
|
||||
n_denses = 1
|
||||
new_dims = (4, 12)
|
||||
activation = Relu
|
||||
|
||||
reg = L1L2(3.2e-5, 3.2e-4)
|
||||
final_reg = L1L2(3.2e-5, 1e-3)
|
||||
actreg_lamb = None # 1e-3
|
||||
|
||||
load_fn = None
|
||||
save_fn = 'mnist.h5'
|
||||
log_fn = 'mnist_losses.npz'
|
||||
|
||||
fn = 'mnist.npz'
|
||||
mnist_dim = 28
|
||||
mnist_classes = 10
|
||||
|
@ -74,23 +89,39 @@ def get_mnist(fn='mnist.npz'):
|
|||
|
||||
inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)
|
||||
|
||||
def actreg(y):
    """Optionally append an activity-regularization node after ``y``.

    When the module-level ``actreg_lamb`` is falsy (e.g. ``None`` or ``0``)
    ``y`` is returned unchanged. Otherwise ``y`` is fed into an
    ``ActivityRegularizer`` wrapping ``SaturateRelu(actreg_lamb)`` and the
    result of that ``feed`` call is returned.
    """
    if actreg_lamb:
        # Scaling by np.prod(y.output_shape) is currently disabled.
        strength = actreg_lamb
        penalty = SaturateRelu(strength)
        regularizer = ActivityRegularizer(penalty)
        # HACK: remember the starting strength on the penalty object so the
        # training loop can rescale ``lamb`` relative to it.
        penalty.lamb_orig = penalty.lamb
        return y.feed(regularizer)
    return y
|
||||
|
||||
x = Input(shape=inputs.shape[1:])
|
||||
y = x
|
||||
|
||||
y = y.feed(Reshape(new_shape=(mnist_dim, mnist_dim)))
|
||||
for i in range(n_denses):
|
||||
if i > 0:
|
||||
y = actreg(y)
|
||||
y = y.feed(activation())
|
||||
y = y.feed(Denses(new_dims[0], axis=0, init=init_he_normal))
|
||||
y = y.feed(Denses(new_dims[1], axis=1, init=init_he_normal))
|
||||
y = y.feed(Denses(new_dims[0], axis=0, init=init_he_normal,
|
||||
reg_w=reg, reg_b=reg))
|
||||
y = y.feed(Denses(new_dims[1], axis=1, init=init_he_normal,
|
||||
reg_w=reg, reg_b=reg))
|
||||
y = y.feed(Flatten())
|
||||
for i in range(n_dense):
|
||||
if i > 0:
|
||||
y = actreg(y)
|
||||
y = y.feed(activation())
|
||||
y = y.feed(Dense(y.output_shape[0], init=init_he_normal))
|
||||
y = y.feed(Dense(y.output_shape[0], init=init_he_normal,
|
||||
reg_w=reg, reg_b=reg))
|
||||
y = actreg(y)
|
||||
y = y.feed(activation())
|
||||
|
||||
y = y.feed(Dense(mnist_classes, init=init_glorot_uniform))
|
||||
y = y.feed(Dense(mnist_classes, init=init_glorot_uniform,
|
||||
reg_w=final_reg, reg_b=final_reg))
|
||||
y = y.feed(Softmax())
|
||||
|
||||
model = Model(x, y, unsafe=True)
|
||||
|
@ -101,10 +132,8 @@ if sgdr:
|
|||
restarts=starts-1, restart_decay=restart_decay,
|
||||
expando=lambda i:0)
|
||||
else:
|
||||
# learner = TriangularCLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||
# frequency=epochs//starts)
|
||||
learner = SineCLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||
frequency=epochs//starts)
|
||||
learner = CLR(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
|
||||
frequency=epochs//starts)
|
||||
|
||||
loss = CategoricalCrossentropy()
|
||||
mloss = Accuracy()
|
||||
|
@ -113,6 +142,11 @@ ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
|
|||
#ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss,
|
||||
# input_noise=1e-1, output_noise=3.2e-2, gradient_noise=1e-1)
|
||||
|
||||
# Dump the model graph as one "parent->child" edge per line; nodes without
# children print nothing.
for node in model.ordered_nodes:
    parent = str(node)
    edges = [parent + '->' + str(child) for child in node.children]
    if edges:
        print('\n'.join(edges))
|
||||
log('parameters', model.param_count)
|
||||
|
||||
ritual.prepare(model)
|
||||
|
@ -151,6 +185,12 @@ def measure_error(quiet=False):
|
|||
measure_error()
|
||||
|
||||
while learner.next():
|
||||
act_t = (learner.epoch - 1) / (learner.epochs - 1)
|
||||
if actreg_lamb:
|
||||
for node in model.ordered_nodes:
|
||||
if isinstance(node, ActivityRegularizer):
|
||||
node.reg.lamb = act_t * node.reg.lamb_orig # HACK
|
||||
|
||||
indices = np.arange(inputs.shape[0])
|
||||
np.random.shuffle(indices)
|
||||
shuffled_inputs = inputs[indices]
|
||||
|
@ -174,6 +214,10 @@ while learner.next():
|
|||
if not measure_every_epoch:
|
||||
measure_error()
|
||||
|
||||
if save_fn is not None:
|
||||
log('saving weights', save_fn)
|
||||
model.save_weights(save_fn, overwrite=True)
|
||||
|
||||
if log_fn:
|
||||
log('saving losses', log_fn)
|
||||
np.savez_compressed(log_fn,
|
||||
|
|
Loading…
Reference in a new issue