diff --git a/onn_mnist.py b/onn_mnist.py index 1cff3b7..c321bfe 100755 --- a/onn_mnist.py +++ b/onn_mnist.py @@ -77,26 +77,6 @@ else: mnist_classes = 10 def get_mnist(fn='mnist.npz'): - import os.path - if fn == 'mnist.npz' and not os.path.exists(fn): - from keras.datasets import mnist - from keras.utils.np_utils import to_categorical - (X_train, y_train), (X_test, y_test) = mnist.load_data() - X_train = X_train.reshape(X_train.shape[0], 1, mnist_dim, mnist_dim) - X_test = X_test.reshape(X_test.shape[0], 1, mnist_dim, mnist_dim) - X_train = X_train.astype('float32') / 255 - X_test = X_test.astype('float32') / 255 - Y_train = to_categorical(y_train, mnist_classes) - Y_test = to_categorical(y_test, mnist_classes) - np.savez_compressed(fn, - X_train=X_train, - Y_train=Y_train, - X_test=X_test, - Y_test=Y_test) - lament("mnist successfully saved to", fn) - lament("please re-run this program to continue") - sys.exit(1) - with np.load(fn) as f: return f['X_train'], f['Y_train'], f['X_test'], f['Y_test']