gists/resnet/resnet.py

185 lines
5.4 KiB
Python
Raw Normal View History

2017-01-05 04:55:10 -08:00
#!/usr/bin/env python3
import keras.backend as K
# Everything below builds channels-first tensors: inputs are reshaped to
# (samples, 1, width, height) and BatchNormalization normalizes axis=1.
# Raise explicitly instead of using `assert`, which is stripped when Python
# runs with -O and would let a mis-configured backend through silently.
if K.image_dim_ordering() != 'th':
    raise RuntimeError(
        "this script requires the Keras image_dim_ordering 'th', got {!r}"
        .format(K.image_dim_ordering()))
2016-08-09 01:11:05 -07:00
import pickle, time
import sys
import numpy as np
from keras.callbacks import LearningRateScheduler
from keras.datasets import mnist
from keras.layers import BatchNormalization
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Flatten, Reshape
2017-01-05 04:55:10 -08:00
from keras.layers import Input, merge, Dense, Activation
2016-08-09 01:11:05 -07:00
from keras.models import Model
from keras.utils.np_utils import to_categorical
# Problem dimensions: used to reshape the input arrays and to size the
# softmax output layer.
nb_classes = 10
width = height = 28

# Objective passed to model.compile().
loss = 'categorical_crossentropy'

# Per-run name built from the current UNIX time (rounded to whole seconds);
# it is the stem of the .json/.pkl files written after training.
name = 'resnet-{:.0f}'.format(time.time())
import os


def latest_checkpoint(directory):
    """Return the path of the most recent ``resnet-*.pkl`` file in *directory*.

    "Most recent" means last in lexicographic order of the file names; the
    names written by this script embed the UNIX time of the run.

    Raises:
        FileNotFoundError: if *directory* holds no matching file.
    """
    candidates = sorted(
        fn for fn in os.listdir(directory)
        if fn.startswith('resnet-') and fn.endswith('.pkl'))
    if not candidates:
        raise FileNotFoundError(
            "couldn't find any appropriate .pkl files in {!r}".format(directory))
    return os.path.join(directory, candidates[-1])


# Optional first command-line argument: either a weights file to restore, or
# a directory (e.g. '.') in which to pick the most recent resnet-*.pkl file.
args = dict(enumerate(sys.argv))
restore_fn = args.get(1)
if restore_fn is not None and os.path.isdir(restore_fn):
    restore_fn = latest_checkpoint(restore_fn)
# --- run switches -----------------------------------------------------------
dont_train = False       # True: skip fitting, sparsifying and saving
verbose_summary = False  # True: model.summary() instead of a parameter total

# --- architecture -----------------------------------------------------------
reslayers = 4            # number of residual blocks stacked in the model
size = 8                 # filters (or dense units) produced by every layer
convolutional = True     # False: build the blocks out of Dense layers
resnet_enabled = True    # False: drop the skip-connection sums
original_resnet = False  # True: layer/BN/ReLU order with ReLU after the sum

# --- optimization -----------------------------------------------------------
batch_size = 128
epochs = 24
LR = 1e-2
LRprod = 0.1 ** (1 / 20.)  # will use a tenth of the learning rate after 20 epochs
use_image_generator = True  # True: train on randomly augmented batches
2017-01-05 04:55:10 -08:00
def prepare(X, y):
    """Turn raw image and label arrays into network inputs and targets.

    X is reshaped to (samples, 1, width, height), cast to float32 and divided
    by 255; y is one-hot encoded over `nb_classes` classes.

    Returns the tuple (inputs, targets).
    """
    n_samples = X.shape[0]
    images = X.reshape(n_samples, 1, width, height)
    images = images.astype('float32') / 255
    # convert class vectors to binary class matrices
    onehot = to_categorical(y, nb_classes)
    return images, onehot
2016-08-09 01:11:05 -07:00
# Load MNIST (already split into train and test sets), then convert both
# splits to network-ready arrays. Lower-case y_* keep the raw labels,
# upper-case Y_* hold the one-hot targets returned by prepare().
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train, Y_train = prepare(X_train, y_train)
X_test, Y_test = prepare(X_test, y_test)
2016-08-09 01:11:05 -07:00
if use_image_generator:
    from keras.preprocessing.image import ImageDataGenerator

    # Mild random augmentation: small rotations, shifts, shear and zoom.
    # fill_mode='constant' with cval=0 pads with the constant value 0.
    idg = ImageDataGenerator(
        rotation_range=5.,
        width_shift_range=.10,
        height_shift_range=.10,
        shear_range=5 / 180 * np.pi,
        zoom_range=0.1,
        fill_mode='constant',
        cval=0.,
    )
# ReLU activation is supposed to be the best with he_normal
if convolutional:
    def layer(x):
        """Return a new 3x3 'same'-padded convolution with x filters."""
        return Convolution2D(x, 3, 3, init='he_normal', border_mode='same')
else:
    def layer(x):
        """Return a new fully-connected layer with x units."""
        return Dense(x, init='he_normal')
# start constructing the model
x = Input(shape=(1, width, height))
y = x

# Stem: bring the input to `size` channels/units before the residual blocks.
if convolutional:
    # it might be worth trying other sizes here
    y = Convolution2D(size, 7, 7, subsample=(2, 2), border_mode='same')(y)
    y = MaxPooling2D()(y)
else:
    y = Flatten()(y)
    # Must produce `size` units: every block below ends in layer(size), and
    # the skip connection can only be summed with a tensor of the same shape.
    # (This previously referenced the undefined name `dense_size`.)
    y = Dense(size)(y)

# Residual blocks: two layers each, with an optional identity skip connection.
for i in range(reslayers):
    skip = y
    if original_resnet:
        # layer -> BN -> ReLU -> layer -> BN -> (+ skip) -> ReLU
        y = layer(size)(y)
        y = BatchNormalization(axis=1)(y)
        y = Activation('relu')(y)
        y = layer(size)(y)
        y = BatchNormalization(axis=1)(y)
        if resnet_enabled:
            y = merge([skip, y], mode='sum')
        y = Activation('relu')(y)
    else:
        # BN -> ReLU -> layer -> BN -> ReLU -> layer -> (+ skip)
        y = BatchNormalization(axis=1)(y)
        y = Activation('relu')(y)
        y = layer(size)(y)
        y = BatchNormalization(axis=1)(y)
        y = Activation('relu')(y)
        y = layer(size)(y)
        if resnet_enabled:
            y = merge([skip, y], mode='sum')

if convolutional:
    from keras.layers import AveragePooling1D
    # Collapse the two spatial axes into one. The stem's stride-2 convolution
    # and its max-pooling account for the two 2**2 divisors.
    y = Reshape((size, int(width * height / 2**2 / 2**2)))(y)
    # NOTE(review): after the reshape the pooled axis appears to be the one of
    # length `size` (the feature maps), not the spatial positions -- confirm
    # that this is the intended kind of average pooling.
    y = AveragePooling1D(size)(y)
    y = Flatten()(y)

# Classifier head.
y = Dense(nb_classes)(y)
y = Activation('softmax')(y)

model = Model(input=x, output=y)
# Report the model size: either the full Keras summary or only the total
# number of parameters.
if verbose_summary:
    model.summary()
else:
    # Sum with a generator expression so no loop variable leaks into the
    # module namespace; the previous `for layer in ...` loop overwrote the
    # module-level `layer` factory.
    total_params = sum(lyr.count_params() for lyr in model.layers)
    print("Total params: {}".format(total_params))
# Optionally start from previously saved weights.
# NOTE(review): the nesting below was reconstructed; set_weights() needs W,
# which only exists when restore_fn is set, and the LR reduction is assumed
# to apply only to restored runs -- confirm against the original file.
if restore_fn:
    # SECURITY: unpickling can execute arbitrary code. Only restore weight
    # files that were produced by a trusted run of this script.
    with open(restore_fn, 'rb') as f:
        W = pickle.load(f)

    if not dont_train:
        # sparsify an existing model: in every (size, size, 3, 3) kernel,
        # zero the weights whose magnitude is below that kernel's median
        # magnitude
        for i, w in enumerate(W):
            if w.shape != (size, size, 3, 3):
                continue
            magnitude = np.abs(w)
            small = magnitude < np.median(magnitude)
            fmt = 'W[{}]: zeroing {} params of {}'
            print(fmt.format(i, int(np.count_nonzero(small)), int(w.size)))
            W[i] = np.where(small, 0, w)

    model.set_weights(W)
    # Restored runs start at a tenth of the base learning rate.
    LR /= 10
model.compile(loss=loss, optimizer='adam', metrics=['accuracy'])

if not dont_train:
    # Exponential decay: epoch e is trained with learning rate LR * LRprod**e.
    callbacks = [LearningRateScheduler(lambda e: LR * LRprod**e)]

    # Arguments shared by both training entry points.
    kwargs = {
        'nb_epoch': epochs,
        'validation_data': (X_test, Y_test),
        'callbacks': callbacks,
        'verbose': 1,
    }

    if use_image_generator:
        # Train on augmented batches produced by the generator.
        batches = idg.flow(X_train, Y_train, batch_size=batch_size)
        history = model.fit_generator(
            batches, samples_per_epoch=len(X_train), **kwargs)
    else:
        history = model.fit(
            X_train, Y_train, batch_size=batch_size, **kwargs)
def evaluate(X, Y):
    """Evaluate the module-level `model` on (X, Y) and print each metric.

    The metric named "acc" is printed as a percentage with two decimals;
    every other metric is printed as a plain number with five decimals.
    """
    values = model.evaluate(X, Y, verbose=0)
    for metric, value in zip(model.metrics_names, values):
        if metric == "acc":
            line = "{:7} {:6.2f}%".format(metric, value * 100)
        else:
            line = "{:7} {:7.5f}".format(metric, value)
        print(line)
# Report metrics on the training set, the test set, and both combined.
print('TRAIN')
evaluate(X_train, Y_train)

print('TEST')
evaluate(X_test, Y_test)

print('ALL')
evaluate(np.concatenate((X_train, X_test), axis=0),
         np.concatenate((Y_train, Y_test), axis=0))
if not dont_train:
    # Persist the trained model under this run's name: the architecture as
    # JSON and the weights as a pickle. Both files are opened with `with` so
    # the handles are flushed and closed deterministically (the JSON file was
    # previously opened without ever being closed).
    with open(name + '.json', 'w') as f:
        f.write(model.to_json())
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(model.get_weights(), f)