commit c65a3a64aa8fb0ab55f0efa01f504ca417322c22 Author: Connor Olding Date: Wed Mar 14 16:16:14 2018 +0100 init diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2e58a33 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (C) 2018 Connor Olding + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..e61cf65 --- /dev/null +++ b/README.md @@ -0,0 +1,50 @@ +# mnists + +downloads and prepares various mnist-compatible datasets. + +files are downloaded to `~/.mnist` and checked for integrity by sha256 hashes. + +**dependencies:** numpy + +**install:** `pip install --upgrade --upgrade-strategy only-if-needed mnists` + +I've added --upgrade-strategy to the command-line +so you don't accidentally "upgrade" numpy to +a version not compiled specifically for your system. 
+ +## usage + +```python +import mnists + +dataset = "emnist_balanced" +train_images, train_labels, test_images, test_labels = mnists.prepare(dataset) +``` + +the default output shape is (n, 1, 28, 28). +pass `flatten=True` to `mnists.prepare` to get (n, 784). + +## datasets + +in alphabetical order: + +### [emnist][emnist] + +* `emnist_balanced` +* `emnist_byclass` +* `emnist_bymerge` +* `emnist_digits` +* `emnist_letters` +* `emnist_mnist` + +### [fashion-mnist][fashion-mnist] + +* `fashion_mnist` + +### [mnist][mnist] + +* `mnist` + +[emnist]: //www.nist.gov/itl/iad/image-group/emnist-dataset +[fashion-mnist]: //github.com/zalandoresearch/fashion-mnist +[mnist]: http://yann.lecun.com/exdb/mnist/ diff --git a/TODO b/TODO new file mode 100644 index 0000000..ab852ed --- /dev/null +++ b/TODO @@ -0,0 +1,21 @@ +TODO + + * abide to PEP 8 + + * finish writing README + + * finish npz functionality + + * document prepare() function + + * support python 3.5 + + * adjust dates created/modified on server-hosted files to something sensible + + * basic tests (including PEP 8) + + * submit to pypi + + * document everything + + * test everything diff --git a/mnists/__init__.py b/mnists/__init__.py new file mode 100644 index 0000000..20fb634 --- /dev/null +++ b/mnists/__init__.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# mnists +# Copyright (C) 2018 Connor Olding +# Distributed under terms of the MIT license. + +# NOTE: npz functionality is incomplete. 
import array
import gzip
import hashlib
import os
import os.path
import struct
import sys
from urllib.request import urlretrieve

import numpy as np


# Datasets are cached beneath ~/.mnist and fetched from this mirror.
home = os.path.expanduser("~")
output_directory = os.path.join(home, ".mnist")
webhost = "https://eaguru.guru/mnist/"


def lament(*args, **kwargs):
    """Print a message to stderr."""
    print(*args, file=sys.stderr, **kwargs)


def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
    """Build a dataset metadata tuple.

    Returns (prefix, npz name, train images .gz, train labels .gz,
    test images .gz, test labels .gz).
    """
    images_suffix = "-images-idx3-ubyte.gz"
    labels_suffix = "-labels-idx1-ubyte.gz"
    return (prefix, npz,
            train_part + images_suffix,
            train_part + labels_suffix,
            test_part + images_suffix,
            test_part + labels_suffix)


def make_emnist_meta(npz, name):
    """Build a metadata tuple for an emnist variant (its own file naming)."""
    return make_meta(npz, name + "-train", name + "-test", prefix="emnist")


# Maps a dataset identifier to (subdirectory, npz file, four idx .gz files).
metadata = dict(
    emnist_balanced=make_emnist_meta("emnist_balanced.npz", "emnist-balanced"),
    emnist_byclass=make_emnist_meta("emnist_byclass.npz", "emnist-byclass"),
    emnist_bymerge=make_emnist_meta("emnist_bymerge.npz", "emnist-bymerge"),
    emnist_digits=make_emnist_meta("emnist_digits.npz", "emnist-digits"),
    emnist_letters=make_emnist_meta("emnist_letters.npz", "emnist-letters"),
    emnist_mnist=make_emnist_meta("emnist_mnist.npz", "emnist-mnist"),
    fashion_mnist=make_meta("fashion_mnist.npz", prefix="fashion-mnist"),
    mnist=make_meta("mnist.npz", prefix="mnist"),
)


# Known-good sha256 digests: one "relative-path digest" pair per line.
hash_data = """
emnist/emnist-balanced-test-images-idx3-ubyte.gz 117fc12b015e37ea98bfb2cbba3f77c3782c056cdbd219b87abfc8661d7135db
emnist/emnist-balanced-test-labels-idx1-ubyte.gz a67cac948379b67270b9a4d0a01c8ec661598accf2f2b1964340b7c83cb175cd
emnist/emnist-balanced-train-images-idx3-ubyte.gz 74295e75e4e9fa2c102147460ee7eb763653d5132d883a84fb88932bfb8d2db4
emnist/emnist-balanced-train-labels-idx1-ubyte.gz b46d41e5156a0a8aaaa4d18979fdd1588850fa8571018acffd5d93b325a7d6ed
emnist/emnist-byclass-test-images-idx3-ubyte.gz 5ec835bd5d5b1a66378a132e727fef20fe994b3cc666899befcca5ac766f50cd
emnist/emnist-byclass-test-labels-idx1-ubyte.gz 501e305eeeacb093b162592d80e8a2632040872360fa6d85cfcb2200d91ca92a
emnist/emnist-byclass-train-images-idx3-ubyte.gz 9e21987f03a6eb8b870187e785883ddad96947299587908f1a7ec2ca0f4be5f0
emnist/emnist-byclass-train-labels-idx1-ubyte.gz f7ef403d48dc18a89dae0d83d11ee69cecb6e1c12c3a961d03712e495c02a8fa
emnist/emnist-bymerge-test-images-idx3-ubyte.gz 5e5815ec1b522089134435f7c55e75a00960b58414793593ffff23efc0821683
emnist/emnist-bymerge-test-labels-idx1-ubyte.gz 8bda1516255683ba9071e141947ca0c91a911176065cdd6259b0b4d8dc5b5f35
emnist/emnist-bymerge-train-images-idx3-ubyte.gz 39bf8c68037dad40db35348895f7e4e056e8c0a005ab78869c9fbc2c96e947be
emnist/emnist-bymerge-train-labels-idx1-ubyte.gz 2c47293b9e62d6abf6ad05fbf67e40baa32adfd380b8d7b44dcc35534b64e1f1
emnist/emnist-digits-test-images-idx3-ubyte.gz 20ebe43509264f7639d37c835a3f0b009e90b29e00a099a03101abe1644a3f63
emnist/emnist-digits-test-labels-idx1-ubyte.gz bad1834c45d5988e270b99d43b13edc3fd898f658ffbc2caf2846ef27f7b05b0
emnist/emnist-digits-train-images-idx3-ubyte.gz f7c95004d14d81af89522e67d8f0c781dfb3bd544181c81bfe1eb6b41c35b726
emnist/emnist-digits-train-labels-idx1-ubyte.gz 6638a6ff5fe2eefd9cca2471995b46c82986b8e4b2a70d2f5ce8e05d7edb319e
emnist/emnist-letters-test-images-idx3-ubyte.gz 50f0e93d75b99b463e9760b8345866f22587fa2c2fba472730a5ba09c693412d
emnist/emnist-letters-test-labels-idx1-ubyte.gz 924955c31fe7fd809b7303b5c0141282955f4b8390c74cdfd5d80435e0f4b4cb
emnist/emnist-letters-train-images-idx3-ubyte.gz 6288f418a917e007bf5fddb823a2dc7a5c6a35031401b93292a539cb12bd881f
emnist/emnist-letters-train-labels-idx1-ubyte.gz 9759af91ea6bccdf07a6040fd085861f2ec88e37f7c3f5760f6feb57106fbe5c
emnist/emnist-mnist-test-images-idx3-ubyte.gz b25640fa9e356c618132f5df347c95ef5bae6d9c2596be1de40f7e0ae1c3eac6
emnist/emnist-mnist-test-labels-idx1-ubyte.gz bcc5385d38f083dba046f5f756e99b32f5ea7fad122c6243c71c2239873065ca
emnist/emnist-mnist-train-images-idx3-ubyte.gz 5f9942f441031c0f1f2e4d162059fd6e19ea808b34c328b6d16cab0ed24c78f2
emnist/emnist-mnist-train-labels-idx1-ubyte.gz 00c6d3c2342fdffb1711e0b5656cac4ce5862d2826399d383a9bac07d612e9f8
fashion-mnist/t10k-images-idx3-ubyte.gz 78dcbbfb5a27efaf1b6c1d616caca68560be8766c5bffcb2791df9273a534229
fashion-mnist/t10k-labels-idx1-ubyte.gz 42bd18137a62d5998cdeae52bf3a0676ac6b706f5cf8439b47bb5b151ae3dccf
fashion-mnist/train-images-idx3-ubyte.gz 08dc20ab1689590a0bcd66e874647b6ee8464e2d501b5a3f1f78831db19a3fdc
fashion-mnist/train-labels-idx1-ubyte.gz b0197879cbda89f3dc7b894f9fd52b858e68ea4182b6947c9d8c2b67e5f18dcc
mnist/t10k-images-idx3-ubyte.gz 8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6
mnist/t10k-labels-idx1-ubyte.gz f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6
mnist/train-images-idx3-ubyte.gz 440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609
mnist/train-labels-idx1-ubyte.gz 3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c
"""

hashes = dict(line.split(" ") for line in hash_data.strip().split("\n"))

# correct the path separators for the user's system.
hashes = dict((os.path.join(*k.split("/")), v) for k, v in hashes.items())


def construct_path(name):
    """Return the absolute path of `name` inside the output directory."""
    return os.path.join(output_directory, name)


def make_directories():
    """Create the output directory plus one subdirectory per dataset prefix."""
    directories = [output_directory]
    directories += [construct_path(meta[0]) for meta in metadata.values()]
    for directory in directories:
        if not os.path.isdir(directory):
            os.mkdir(directory)


def download(name):
    """Download `name` from the webhost unless it already exists locally.

    Returns True if the file already existed, False if it was downloaded.
    Re-raises any error from urlretrieve after printing a diagnostic.
    """
    url_name = "/".join(os.path.split(name))
    path = construct_path(name)
    already_exists = os.path.isfile(path)
    if not already_exists:
        url = webhost + url_name
        try:
            urlretrieve(url, path)
        except Exception:  # was a bare `except:`; don't catch SystemExit etc.
            lament(f"Failed to download {url} to {path}")
            raise
    return already_exists


def validate(name):
    """Verify the sha256 digest of a downloaded file against the known value.

    Raises an Exception on unknown names or on a digest mismatch.
    """
    if name not in hashes.keys():
        raise Exception(f"Unknown mnist dataset: {name}")

    with open(construct_path(name), "rb") as f:
        data = f.read()

    known_hash = hashes[name]
    # renamed from `hash`, which shadowed the builtin.
    digest = hashlib.sha256(data).hexdigest()

    if digest != known_hash:
        raise Exception(f"""Failed to validate dataset: {name}
Hash mismatch: {digest} should be {known_hash}
Please check your local file for tampering or corruption.""")


def onehot(ind):
    """One-hot encode integer labels.

    Assumes labels are contiguous from 0 (see squeeze); the width of the
    encoding is the number of distinct labels present.
    """
    unique = np.unique(ind)
    hot = np.zeros((len(ind), len(unique)), dtype=np.int8)
    offsets = np.arange(len(ind)) * len(unique)
    hot.flat[offsets + ind.flat] = 1
    return hot


def read_array(f):
    """Read the remainder of a binary file object as a uint8 numpy array."""
    return np.array(array.array("B", f.read()), dtype=np.uint8)


def open_maybe_gzip(path):
    """Open a file for binary reading, transparently decompressing .gz."""
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, "rb")


def load(images_path, labels_path):
    """Load an idx-format image/label file pair.

    Returns (images, labels): images shaped (n, 1, rows, cols) uint8,
    labels shaped (n,) uint8.
    """
    # load labels first so we can determine how many images there are.
    with open_maybe_gzip(labels_path) as f:
        magic, num = struct.unpack(">II", f.read(8))
        labels = read_array(f)

    with open_maybe_gzip(images_path) as f:
        magic, num, rows, cols = struct.unpack(">IIII", f.read(16))
        images = read_array(f).reshape(len(labels), 1, rows, cols)

    return images, labels


def squeeze(labels):
    """Remap label values onto contiguous integers starting from 0.

    Returns `labels` unchanged when they are already contiguous from 0.
    """
    unique = np.unique(labels)
    new_values = np.arange(len(unique))
    if np.all(new_values == unique):
        return labels

    relabelled = np.zeros_like(labels)
    for i in new_values:
        relabelled[labels == unique[i]] = i
    return relabelled


def batch_flatten(x):
    """Flatten each sample to one dimension, keeping the batch dimension."""
    return x.reshape(len(x), -1)


def prepare(dataset="mnist", return_floats=True, return_onehot=True,
            flatten=False, squeeze_layers=True):
    """Download (if needed), validate, and load a dataset.

    dataset: one of the keys of `metadata`.
    return_floats: scale images to float32 in [0, 1] instead of uint8.
    return_onehot: one-hot encode the labels.
    flatten: reshape images from (n, 1, 28, 28) to (n, 784).
    squeeze_layers: remap labels to be contiguous from 0.

    Returns (train_images, train_labels, test_images, test_labels).
    Raises Exception on unknown dataset names or failed validation.
    """
    if dataset not in metadata.keys():
        raise Exception(f"Unknown dataset: {dataset}")

    meta = metadata[dataset]
    prefix, names = meta[0], meta[1:]
    names = [os.path.join(prefix, name) for name in names]
    npz, train_images, train_labels, test_images, test_labels = names
    images_and_labels = names[1:]

    make_directories()

    # NOTE: npz caching is incomplete: the .npz file is read when present
    # but is never written by this module.
    npz_exists = os.path.isfile(construct_path(npz))

    for name in images_and_labels:
        download(name)
        validate(name)

    if npz_exists:
        # bugfix: the original `with np.load(...):` lacked `as f`,
        # so the subscripts below raised NameError.
        with np.load(construct_path(npz)) as f:
            train_images_data = f["train_images"]
            train_labels_data = f["train_labels"]
            test_images_data = f["test_images"]
            test_labels_data = f["test_labels"]
    else:
        train_images_data, train_labels_data = load(
            construct_path(train_images), construct_path(train_labels))
        test_images_data, test_labels_data = load(
            construct_path(test_images), construct_path(test_labels))

    # correct the orientation of emnist images (stored transposed).
    if prefix == "emnist":
        train_images_data = train_images_data.transpose(0, 1, 3, 2)
        test_images_data = test_images_data.transpose(0, 1, 3, 2)

    if return_floats:  # TODO: better name.
        train_images_data = train_images_data.astype(np.float32) / np.float32(255)
        test_images_data = test_images_data.astype(np.float32) / np.float32(255)

    # emnist_letters uses labels indexed from 1 instead of the usual 0.
    # the onehot function assumes labels are contiguous from 0.
    if squeeze_layers or return_onehot:
        train_labels_data = squeeze(train_labels_data)
        test_labels_data = squeeze(test_labels_data)

    if return_onehot:
        train_labels_data = onehot(train_labels_data)
        test_labels_data = onehot(test_labels_data)

    if flatten:
        train_images_data = batch_flatten(train_images_data)
        test_images_data = batch_flatten(test_images_data)

    return (train_images_data, train_labels_data,
            test_images_data, test_labels_data)

    # kanpeki!
# --- mnists/__main__.py ---
from . import metadata, prepare

# `python -m mnists` downloads and prepares every known dataset.
for name in metadata:
    print(name)
    prepare(name)


# --- setup.py ---
from setuptools import setup

setup(
    name='mnists',
    version='0.1.0',
    packages=[
        'mnists',
    ],

    author='notwa',
    author_email='cloningdonor+pypi@gmail.com',
    url='https://github.com/notwa/mnists',
    keywords='TODO',
    description='downloads and prepares various mnist-compatible datasets.',
    license='MIT',
    zip_safe=True,

    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Natural Language :: English',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Topic :: Scientific/Engineering',
    ]
)