#!/usr/bin/env python3

# explicit imports for sys.exit and np used below, in case the star import
# from onn does not re-export them.
import sys
import numpy as np

from onn import *
from onn_core import _f
from dotmap import DotMap

lower_priority()
np.random.seed(42069)

measure_every_epoch = True
target_boost = lambda y: y

use_emnist = False
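
# The `if use_emnist:` switch below selects one of two hyperparameter blocks:
# EMNIST-Balanced (47 classes) or plain MNIST (10 classes).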

if use_emnist:
    lr = 1.0
    epochs = 48
    starts = 2
    bs = 400

    learner_class = SGDR
    restart_decay = 0.5

    n_dense = 2
    n_denses = 0
    new_dims = (28, 28)
    activation = GeluApprox
    output_activation = Softmax
    normalize = True

    optim = MomentumClip(mu=0.7, nesterov=True)
    restart_optim = False

    reg = None  # L1L2(2.0e-5, 1.0e-4)
    final_reg = None  # L1L2(2.0e-5, 1.0e-4)
    dropout = 0.33
    actreg_lamb = None

    load_fn = None
    save_fn = 'emnist.h5'
    log_fn = 'emnist_losses.npz'

    fn = 'emnist-balanced.npz'
    mnist_dim = 28
    mnist_classes = 47

else:
    lr = 0.01
    epochs = 60
    starts = 3
    bs = 500

    learner_class = SGDR
    restart_decay = 0.5

    n_dense = 2
    n_denses = 1
    new_dims = (4, 12)
    activation = GeluApprox
    output_activation = Softmax
    normalize = True

    optim = MomentumClip(0.8, 0.8)
    restart_optim = False

    reg = None  # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 3.2e-4)
    final_reg = None  # L1L2(1e-6, 1e-5) # L1L2(3.2e-5, 1e-3)
    dropout = None  # 0.05
    actreg_lamb = None  # 1e-4

    load_fn = None
    save_fn = 'mnist.h5'
    log_fn = 'mnist_losses.npz'

    fn = 'mnist.npz'
    mnist_dim = 28
    mnist_classes = 10

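# get_mnist loads the cached dataset archive. For plain MNIST it will, on the
# first run, fetch the data via Keras, one-hot the labels, cache everything to
# an .npz, and exit so the script can be re-run against the cache. The EMNIST
# archive ('emnist-balanced.npz') is assumed to exist already.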
def get_mnist(fn='mnist.npz'):
    import os.path
    if fn == 'mnist.npz' and not os.path.exists(fn):
        from keras.datasets import mnist
        from keras.utils.np_utils import to_categorical
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        X_train = X_train.reshape(X_train.shape[0], 1, mnist_dim, mnist_dim)
        X_test = X_test.reshape(X_test.shape[0], 1, mnist_dim, mnist_dim)
        X_train = X_train.astype('float32') / 255
        X_test = X_test.astype('float32') / 255
        Y_train = to_categorical(y_train, mnist_classes)
        Y_test = to_categorical(y_test, mnist_classes)
        np.savez_compressed(fn,
                            X_train=X_train,
                            Y_train=Y_train,
                            X_test=X_test,
                            Y_test=Y_test)
        lament("mnist successfully saved to", fn)
        lament("please re-run this program to continue")
        sys.exit(1)

    with np.load(fn) as f:
        return f['X_train'], f['Y_train'], f['X_test'], f['Y_test']

inputs, outputs, valid_inputs, valid_outputs = get_mnist(fn)

outputs = target_boost(outputs)
valid_outputs = target_boost(valid_outputs)
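
# regulate() appends the optional per-layer extras -- activity regularization,
# layer normalization, dropout -- according to the flags configured above.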
def regulate(y):
    if actreg_lamb:
        assert activation == Relu, activation
        lamb = actreg_lamb  # * np.prod(y.output_shape)
        reg = SaturateRelu(lamb)
        act = ActivityRegularizer(reg)
        reg.lamb_orig = reg.lamb  # HACK
        y = y.feed(act)
    if normalize:
        y = y.feed(LayerNorm())
    if dropout:
        y = y.feed(Dropout(dropout))
    return y
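
# Build the network graph: an optional stack of per-axis Denses layers over
# the 28x28 image, a stack of fully-connected layers, and a classifier head.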
x = Input(shape=inputs.shape[1:])
y = x

y = y.feed(Reshape(new_shape=(mnist_dim, mnist_dim)))
for i in range(n_denses):
    if i > 0:
        y = regulate(y)
        y = y.feed(activation())
    y = y.feed(Denses(new_dims[0], axis=0, init=init_he_normal,
                      reg_w=reg, reg_b=reg))
    y = y.feed(Denses(new_dims[1], axis=1, init=init_he_normal,
                      reg_w=reg, reg_b=reg))
y = y.feed(Flatten())
for i in range(n_dense):
    if i > 0:
        y = regulate(y)
        y = y.feed(activation())
    y = y.feed(Dense(y.output_shape[0], init=init_he_normal,
                     reg_w=reg, reg_b=reg))
y = regulate(y)
y = y.feed(activation())

y = y.feed(Dense(mnist_classes, init=init_glorot_uniform,
                 reg_w=final_reg, reg_b=final_reg))
y = y.feed(output_activation())
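
# Cross-entropy pairs with a probabilistic output layer (Softmax/Sigmoid);
# anything else falls back to a halved squared error. Accuracy is tracked as
# the secondary ("m") loss.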
if output_activation in (Softmax, Sigmoid):
    loss = CategoricalCrossentropy()
else:
    loss = SquaredHalved()
mloss = Accuracy()

model = Model(x, y, loss=loss, mloss=mloss, unsafe=True)
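
# rscb is passed as the restart callback to the warm-restart learners below;
# it can optionally reset the optimizer state at each restart.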
def rscb(restart):
    log("restarting", restart)
    if restart_optim:
        optim.reset()
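
# Select a learning-rate schedule. With SGDR, training runs epochs//starts
# epochs per cycle and the peak rate decays by restart_decay after each warm
# restart; the other branches cover the CLR variants and simpler learners.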
if learner_class == SGDR:
    learner = learner_class(optim, epochs=epochs//starts, rate=lr,
                            restarts=starts-1, restart_decay=restart_decay,
                            expando=lambda i: 0,
                            callback=rscb)
elif learner_class in (TriangularCLR, SineCLR, WaveCLR):
    learner = learner_class(optim, epochs=epochs, lower_rate=0, upper_rate=lr,
                            frequency=epochs//starts,
                            callback=rscb)
elif learner_class is AnnealingLearner:
    learner = learner_class(optim, epochs=epochs, rate=lr,
                            halve_every=epochs//starts)
elif learner_class is DumbLearner:
    # DumbLearner is given the optimizer directly; no `self` exists at module scope.
    learner = learner_class(optim, epochs=epochs//starts, rate=lr,
                            halve_every=epochs//(2*starts),
                            restarts=starts-1, restart_advance=epochs//starts,
                            callback=rscb)
elif learner_class is Learner:
    learner = Learner(optim, epochs=epochs, rate=lr)
else:
    if not isinstance(optim, YellowFin):
        lament('WARNING: no learning rate schedule selected.')
    learner = Learner(optim, epochs=epochs)
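
# Ritual appears to encapsulate the batched train/test procedure for a
# prepared model (see train_batched/test_batched below).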
ritual = Ritual(learner=learner)

model.print_graph()
log('parameters', model.param_count)

ritual.prepare(model)
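
# DotMap allows attribute-style access (logs.train_losses, etc.) to the
# running logs collected during training.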
logs = DotMap(
    batch_losses=[],
    batch_mlosses=[],
    train_losses=[],
    train_mlosses=[],
    valid_losses=[],
    valid_mlosses=[],
    learning_rate=[],
    momentum=[],
)
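
# measure_error computes full train/validation loss and accuracy, logging the
# numbers unless quiet=True, and appends them to `logs`.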
def measure_error(quiet=False):
    def print_error(name, inputs, outputs):
        loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')

        if not quiet:
            log(name + " loss", "{:12.6e}".format(loss))
            log(name + " accuracy", "{:6.2f}%".format(mloss * 100))

        return loss, mloss

    loss, mloss = print_error("train", inputs, outputs)
    logs.train_losses.append(loss)
    logs.train_mlosses.append(mloss)
    loss, mloss = print_error("valid", valid_inputs, valid_outputs)
    logs.valid_losses.append(loss)
    logs.valid_mlosses.append(mloss)

measure_error()
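
# Main training loop: one ritual.train_batched pass per epoch. When activity
# regularization is enabled, its strength ramps linearly from zero to the
# configured value over the course of training.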
while learner.next():
    act_t = (learner.epoch - 1) / (learner.epochs - 1)
    if actreg_lamb:
        for node in model.ordered_nodes:
            if isinstance(node, ActivityRegularizer):
                node.reg.lamb = act_t * node.reg.lamb_orig  # HACK

    avg_loss, avg_mloss, losses, mlosses = ritual.train_batched(
        inputs, outputs,
        batch_size=bs,
        return_losses='both')
    fmt = "rate {:10.8f}, loss {:12.6e}, accuracy {:6.2f}%"
    log("epoch {}".format(learner.epoch),
        fmt.format(learner.rate, avg_loss, avg_mloss * 100))

    logs.batch_losses += losses
    logs.batch_mlosses += mlosses

    if measure_every_epoch:
        quiet = learner.epoch != learner.epochs
        measure_error(quiet=quiet)

    logs.learning_rate.append(optim.lr)
    if getattr(optim, 'mu', None):
        logs.momentum.append(optim.mu)

if not measure_every_epoch:
    measure_error()
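
# Persist the trained weights and the collected loss curves.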
if save_fn is not None:
    log('saving weights', save_fn)
    model.save_weights(save_fn, overwrite=True)

if log_fn:
    log('saving losses', log_fn)
    kwargs = dict()
    for k, v in logs.items():
        if len(v) > 0:
            kwargs[k] = np.array(v, dtype=_f)
    np.savez_compressed(log_fn, **kwargs)