mnists/mnists/__init__.py

#!/usr/bin/env python3
# mnists
# Copyright (C) 2018 Connor Olding
# Distributed under terms of the MIT license.
__version__ = "0.2.1"
# NOTE: npz functionality is incomplete.
import array
import gzip
import hashlib
import numpy as np
import os
import os.path
import struct
import sys
from urllib.request import urlretrieve
from .hashes import hashes

home = os.path.expanduser("~")
output_directory = os.path.join(home, ".mnist")
webhost = "https://eaguru.guru/mnist/"

def lament(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)

def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
    images_suffix = "-images-idx3-ubyte.gz"
    labels_suffix = "-labels-idx1-ubyte.gz"
    return (prefix, npz,
            train_part + images_suffix,
            train_part + labels_suffix,
            test_part + images_suffix,
            test_part + labels_suffix)
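
# for instance, make_meta("mnist.npz", prefix="mnist") yields the tuple
# ("mnist", "mnist.npz", "train-images-idx3-ubyte.gz",
#  "train-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz",
#  "t10k-labels-idx1-ubyte.gz").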

def make_emnist_meta(npz, name):
    return make_meta(npz, name + "-train", name + "-test", prefix="emnist")

metadata = dict(
    emnist_balanced=make_emnist_meta("emnist_balanced.npz", "emnist-balanced"),
    emnist_byclass=make_emnist_meta("emnist_byclass.npz", "emnist-byclass"),
    emnist_bymerge=make_emnist_meta("emnist_bymerge.npz", "emnist-bymerge"),
    emnist_digits=make_emnist_meta("emnist_digits.npz", "emnist-digits"),
    emnist_letters=make_emnist_meta("emnist_letters.npz", "emnist-letters"),
    emnist_mnist=make_emnist_meta("emnist_mnist.npz", "emnist-mnist"),
    fashion_mnist=make_meta("fashion_mnist.npz", prefix="fashion-mnist"),
    mnist=make_meta("mnist.npz", prefix="mnist"),
)
# correct the path separators for the user's system.
hashes = dict((os.path.join(*k.split("/")), v) for k, v in hashes.items())
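# e.g. the key "mnist/train-images-idx3-ubyte.gz" stays the same on POSIX
# systems but becomes "mnist\train-images-idx3-ubyte.gz" on Windows.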

def construct_path(name):
    return os.path.join(output_directory, name)

def make_directories():
    directories = [output_directory]
    directories += [construct_path(meta[0]) for meta in metadata.values()]
    for directory in directories:
        if not os.path.isdir(directory):
            os.mkdir(directory)

def download(name):
    url_name = "/".join(os.path.split(name))
    path = construct_path(name)
    already_exists = os.path.isfile(path)
    if not already_exists:
        url = webhost + url_name
        try:
            urlretrieve(url, path)
        except Exception:
            lament(f"Failed to download {url} to {path}")
            raise
    return already_exists

def validate(name):
    if name not in hashes:
        raise Exception(f"Unknown mnist dataset: {name}")
    with open(construct_path(name), "rb") as f:
        data = f.read()
    known_hash = hashes[name]
    hash = hashlib.sha256(data).hexdigest()
    if hash != known_hash:
        raise Exception(f"""Failed to validate dataset: {name}
Hash mismatch: {hash} should be {known_hash}
Please check your local file for tampering or corruption.""")

def onehot(ind):
    # one-hot encode integer labels: row i gets a 1 in column ind[i].
    unique = np.unique(ind)
    hot = np.zeros((len(ind), len(unique)), dtype=np.int8)
    offsets = np.arange(len(ind)) * len(unique)
    hot.flat[offsets + ind.flat] = 1
    return hot
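
# e.g. onehot(np.array([0, 2, 1])) gives
# [[1, 0, 0],
#  [0, 0, 1],
#  [0, 1, 0]]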

def read_array(f):
    return np.array(array.array("B", f.read()), dtype=np.uint8)

def open_maybe_gzip(path):
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, "rb")

def load(images_path, labels_path):
    # load labels first so we can determine how many images there are.
    with open_maybe_gzip(labels_path) as f:
        magic, num = struct.unpack(">II", f.read(8))
        labels = read_array(f)
    with open_maybe_gzip(images_path) as f:
        magic, num, rows, cols = struct.unpack(">IIII", f.read(16))
        images = read_array(f).reshape(len(labels), 1, rows, cols)
    return images, labels
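
# the files are in the IDX format of the original MNIST release: a big-endian
# header (magic 2049 for labels, 2051 for images, then the item count and, for
# images, the row and column counts) followed by raw uint8 data.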

def squeeze(labels):
    # remap labels onto contiguous integers starting from 0.
    unique = np.unique(labels)
    new_values = np.arange(len(unique))
    if np.all(new_values == unique):
        return labels
    relabelled = np.zeros_like(labels)
    for i in new_values:
        relabelled[labels == unique[i]] = i
    return relabelled
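
# e.g. squeeze(np.array([1, 3, 3, 5])) gives [0, 1, 1, 2]; emnist_letters
# labels run from 1 through 26, so they squeeze to 0 through 25.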

def batch_flatten(x):
    return x.reshape(len(x), -1)

def prepare(dataset="mnist", return_floats=True, return_onehot=True,
            flatten=False, squeeze_labels=True, check_integrity=True):
    if dataset not in metadata:
        raise Exception(f"Unknown dataset: {dataset}")
    meta = metadata[dataset]
    prefix, names = meta[0], meta[1:]
    names = [os.path.join(prefix, name) for name in names]
    npz, train_images, train_labels, test_images, test_labels = names
    images_and_labels = names[1:]
    make_directories()
    existing = [os.path.isfile(construct_path(name)) for name in names]
    npz_existing, gz_existing = existing[0], existing[1:]
    for name in images_and_labels:
        download(name)
        if check_integrity:
            validate(name)
    if npz_existing:
        with np.load(construct_path(npz)) as f:
            train_images_data = f["train_images"]
            train_labels_data = f["train_labels"]
            test_images_data = f["test_images"]
            test_labels_data = f["test_labels"]
    else:
        train_images_data, train_labels_data = load(
            construct_path(train_images), construct_path(train_labels))
        test_images_data, test_labels_data = load(
            construct_path(test_images), construct_path(test_labels))
        # correct the orientation of emnist images.
        if prefix == "emnist":
            train_images_data = train_images_data.transpose(0, 1, 3, 2)
            test_images_data = test_images_data.transpose(0, 1, 3, 2)
    if return_floats:  # TODO: better name.
        train_images_data = train_images_data.astype(np.float32) / 255
        test_images_data = test_images_data.astype(np.float32) / 255
    # emnist_letters uses labels indexed from 1 instead of the usual 0.
    # the onehot function assumes labels are contiguous from 0.
    if squeeze_labels or return_onehot:
        train_labels_data = squeeze(train_labels_data)
        test_labels_data = squeeze(test_labels_data)
    if return_onehot:
        train_labels_data = onehot(train_labels_data)
        test_labels_data = onehot(test_labels_data)
    if flatten:
        train_images_data = batch_flatten(train_images_data)
        test_images_data = batch_flatten(test_images_data)
    return (train_images_data, train_labels_data,
            test_images_data, test_labels_data)

# perfect!
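
# a minimal usage sketch, assuming network access to the default webhost
# above; on first use this downloads MNIST into ~/.mnist.
if __name__ == "__main__":
    train_x, train_y, test_x, test_y = prepare("mnist", flatten=True)
    print(train_x.shape, train_y.shape)  # (60000, 784) (60000, 10)
    print(test_x.shape, test_y.shape)  # (10000, 784) (10000, 10)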