#!/usr/bin/env python3
# mnists
# Copyright (C) 2018 Connor Olding
# Distributed under terms of the MIT license.

__version__ = "0.3.2"

import array
import gzip
import hashlib
import logging
import numpy as np
import os
import os.path
import struct
import sys
from urllib.request import urlretrieve

from .hashes import hashes
from .exceptions import *

logger = logging.getLogger(__name__)

# Local cache root (~/.mnist) and the remote host the archives are fetched from.
home = os.path.expanduser("~")
output_directory = os.path.join(home, ".mnist")
webhost = "https://eaguru.guru/mnist/"


def _make_meta(prefix, train_part="train", test_part="t10k"):
    """Describe one dataset as a 5-tuple of names.

    Returns ``(prefix, train_images, train_labels, test_images, test_labels)``
    where each file name is ``<part>-images-idx3-ubyte.gz`` or
    ``<part>-labels-idx1-ubyte.gz`` and *prefix* is the cache subdirectory.
    """
    file_names = tuple(
        part + kind
        for part in (train_part, test_part)
        for kind in ("-images-idx3-ubyte.gz", "-labels-idx1-ubyte.gz"))
    return (prefix,) + file_names


def _make_meta2(name):
    """Describe a dataset whose files are named ``<name>-train*`` / ``<name>-test*``.

    The cache subdirectory is the text before the first hyphen, so all
    ``emnist-*`` variants share one directory.
    """
    prefix = name.split("-", 1)[0]
    return _make_meta(prefix, train_part=name + "-train",
                      test_part=name + "-test")


# Dataset name (as accepted by prepare) -> metadata tuple from _make_meta.
metadata = {
    "emnist_balanced": _make_meta2("emnist-balanced"),
    "emnist_byclass": _make_meta2("emnist-byclass"),
    "emnist_bymerge": _make_meta2("emnist-bymerge"),
    "emnist_digits": _make_meta2("emnist-digits"),
    "emnist_letters": _make_meta2("emnist-letters"),
    "emnist_mnist": _make_meta2("emnist-mnist"),
    "fashion_mnist": _make_meta("fashion-mnist"),
    "mnist": _make_meta("mnist"),
    "qmnist": _make_meta2("qmnist"),
}

# correct the path separators for the user's system.
# Keys in `hashes` use "/" separators; rebuild them with os.path.join so they
# match the names prepare() builds and validate() looks up.
hashes = {os.path.join(*key.split("/")): digest
          for key, digest in hashes.items()}


def construct_path(name):
    """Return the location of *name* inside the local cache directory."""
    return os.path.join(output_directory, name)


def make_directories():
    """Create the cache root and one subdirectory per dataset prefix."""
    directories = [output_directory]
    directories += [construct_path(meta[0]) for meta in metadata.values()]
    for directory in directories:
        if not os.path.isdir(directory):
            logger.info("Creating directory %s", directory)
            # exist_ok: tolerate another process creating it after our check.
            os.makedirs(directory, exist_ok=True)


def download(name):
    """Fetch cache-relative file *name* from `webhost` unless already cached.

    The transfer is written to a ``.part`` sibling and only renamed onto the
    final path once it completed, so a failed or interrupted download never
    leaves a truncated file that later calls would treat as already present.

    Returns True if the file already existed (nothing downloaded), else False.
    Re-raises whatever urlretrieve raised on failure.
    """
    url_name = "/".join(os.path.split(name))
    path = construct_path(name)
    already_exists = os.path.isfile(path)
    if already_exists:
        logger.debug("Not downloading %s, it already exists: %s", name, path)
        return already_exists

    url = webhost + url_name
    partial_path = path + ".part"
    try:
        try:
            urlretrieve(url, partial_path)
        except Exception:
            logger.warning("Failed to download %s to %s", url, path)
            raise
        os.replace(partial_path, path)
    finally:
        # Only still present if the download or the rename failed.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError:
                pass  # best-effort cleanup; let the original error propagate
    return already_exists


def validate(name):
    """Compare the SHA-256 of cached file *name* against its recorded digest.

    Raises UnknownDatasetError if *name* has no recorded digest, and
    IntegrityError(path, expected, actual) if the digests differ.
    """
    if name not in hashes:
        raise UnknownDatasetError(name)
    path = construct_path(name)
    hasher = hashlib.sha256()
    # Hash in 1 MiB chunks so large archives are never fully held in memory.
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            hasher.update(chunk)
    known_hash = hashes[name]
    actual_hash = hasher.hexdigest()
    if actual_hash != known_hash:
        raise IntegrityError(path, known_hash, actual_hash)


def onehot(ind, num_classes=None):
    """One-hot encode a 1-D array of non-negative integer labels.

    Returns an int8 array of shape (len(ind), num_classes) with a single 1 in
    column ind[i] of row i. When *num_classes* is None the width is
    max(ind) + 1 (0 for empty input), so non-contiguous labels still land in
    the right column. Raises IndexError if a label is >= num_classes.
    """
    ind = np.asarray(ind)
    if num_classes is None:
        num_classes = int(ind.max()) + 1 if ind.size else 0
    hot = np.zeros((len(ind), num_classes), dtype=np.int8)
    hot[np.arange(len(ind)), ind] = 1
    return hot


def read_array(f):
    """Read the rest of binary file object *f* as a 1-D uint8 array."""
    return np.array(array.array("B", f.read()), dtype=np.uint8)


def open_maybe_gzip(path, flags="rb"):
    """Open *path* with gzip.open if it ends in '.gz', else with builtin open."""
    if path.endswith(".gz"):
        logger.debug("Opening %s with gzip.open", path)
        return gzip.open(path, flags)
    else:
        logger.debug("Opening %s with builtin open", path)
        return open(path, flags)


def load(images_path, labels_path):
    """Load an IDX image file and its IDX label file (optionally gzipped).

    Returns (images, labels): images is uint8 of shape (N, 1, rows, cols),
    labels is a 1-D uint8 array of length N. The big-endian headers are
    parsed, but the magic numbers and header counts are not checked.
    """
    # load labels first so we can determine how many images there are.
    with open_maybe_gzip(labels_path) as f:
        magic, num = struct.unpack(">II", f.read(8))
        labels = read_array(f)

    with open_maybe_gzip(images_path) as f:
        magic, num, rows, cols = struct.unpack(">IIII", f.read(16))
        images = read_array(f).reshape(len(labels), 1, rows, cols)

    return images, labels


def squeeze(labels, classes=None):
    """Relabel *labels* onto 0..k-1, preserving the order of the raw values.

    *classes* is the sorted array of every raw label value to map (default:
    the unique values of *labels*); it must contain every value in *labels*.
    Pass one shared *classes* for several splits to map them consistently.
    Returns *labels* itself when it already uses 0..k-1, otherwise a new
    array of the same dtype.
    """
    labels = np.asarray(labels)
    if classes is None:
        classes = np.unique(labels)
    if np.array_equal(classes, np.arange(len(classes))):
        return labels
    # For sorted classes, searchsorted yields each label's rank directly.
    return np.searchsorted(classes, labels).astype(labels.dtype)


def batch_flatten(x):
    """Collapse all axes after the first: shape (N, ...) -> (N, prod(...))."""
    return x.reshape(len(x), -1)


def prepare(dataset="mnist", return_floats=True, return_onehot=True,
            flatten=False, squeeze_labels=True, check_integrity=True):
    """Download (if needed), verify and load one dataset's train/test splits.

    Returns (train_images, train_labels, test_images, test_labels).

    Images are (N, 1, rows, cols), or (N, rows*cols) when *flatten*; uint8,
    or float32 scaled into [0, 1] when *return_floats*.
    Labels are raw uint8 values unless *squeeze_labels* or *return_onehot*,
    in which case both splits are remapped onto 0..k-1 with one shared
    mapping; with *return_onehot* they become int8 arrays of shape (N, k),
    the same k for both splits.

    Raises UnknownDatasetError for an unknown *dataset* name, and
    IntegrityError if *check_integrity* and a file's SHA-256 does not match.
    """
    if dataset not in metadata:
        raise UnknownDatasetError(dataset)

    meta = metadata[dataset]
    prefix, names = meta[0], meta[1:]
    names = [os.path.join(prefix, name) for name in names]
    train_images, train_labels, test_images, test_labels = names
    logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
                 dataset, train_images, train_labels, test_images, test_labels)

    make_directories()

    for name in names:
        download(name)
        if check_integrity:
            validate(name)

    train_images_data, train_labels_data = load(
        construct_path(train_images), construct_path(train_labels))
    test_images_data, test_labels_data = load(
        construct_path(test_images), construct_path(test_labels))

    # correct the orientation of emnist images.
    if prefix == "emnist":
        train_images_data = train_images_data.transpose(0, 1, 3, 2)
        test_images_data = test_images_data.transpose(0, 1, 3, 2)

    if return_floats:  # TODO: better name.
        train_images_data = train_images_data.astype(np.float32) / 255
        test_images_data = test_images_data.astype(np.float32) / 255

    # emnist_letters uses labels indexed from 1 instead of the usual 0.
    # the onehot function assumes labels are contiguous from 0.
    if squeeze_labels or return_onehot:
        # One mapping for both splits, so a raw label always becomes the same
        # index/column even if one split lacks some class.
        classes = np.union1d(train_labels_data, test_labels_data)
        train_labels_data = squeeze(train_labels_data, classes)
        test_labels_data = squeeze(test_labels_data, classes)

        if return_onehot:
            train_labels_data = onehot(train_labels_data, len(classes))
            test_labels_data = onehot(test_labels_data, len(classes))

    if flatten:
        train_images_data = batch_flatten(train_images_data)
        test_images_data = batch_flatten(test_images_data)

    return train_images_data, train_labels_data, \
        test_images_data, test_labels_data

# kanpeki!