add Exception subclasses and basic logging

Connor Olding 2018-03-24 11:04:21 +01:00
parent cbc9c4d7e4
commit 6fe25b1157
3 changed files with 47 additions and 15 deletions

View File

@@ -10,6 +10,7 @@ __version__ = "0.2.1"
 import array
 import gzip
 import hashlib
+import logging
 import numpy as np
 import os
 import os.path
@@ -18,17 +19,16 @@ import sys
 from urllib.request import urlretrieve

 from .hashes import hashes
+from .exceptions import *
+
+logger = logging.getLogger(__name__)

 home = os.path.expanduser("~")
 output_directory = os.path.join(home, ".mnist")
 webhost = "https://eaguru.guru/mnist/"

-def lament(*args, **kwargs):
-    print(*args, file=sys.stderr, **kwargs)

 def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
     images_suffix = "-images-idx3-ubyte.gz"
     labels_suffix = "-labels-idx1-ubyte.gz"
@@ -68,6 +68,7 @@ def make_directories():
     directories += [construct_path(meta[0]) for meta in metadata.values()]
     for directory in directories:
         if not os.path.isdir(directory):
+            logger.info("Creating directory %s", directory)
             os.mkdir(directory)
@@ -75,19 +76,21 @@ def download(name):
     url_name = "/".join(os.path.split(name))
     path = construct_path(name)
     already_exists = os.path.isfile(path)
-    if not already_exists:
+    if already_exists:
+        logger.debug("Not downloading %s, it already exists: %s", name, path)
+    else:
         url = webhost + url_name
         try:
             urlretrieve(url, path)
         except Exception:
-            lament(f"Failed to download {url} to {path}")
+            logger.warning("Failed to download %s to %s", url, path)
             raise
     return already_exists


 def validate(name):
     if name not in hashes.keys():
-        raise Exception(f"Unknown mnist dataset: {name}")
+        raise UnknownDatasetError(name)

     with open(construct_path(name), "rb") as f:
         data = f.read()
@@ -96,9 +99,7 @@ def validate(name):
     hash = hashlib.sha256(data).hexdigest()
     if hash != known_hash:
-        raise Exception(f"""Failed to validate dataset: {name}
-Hash mismatch: {hash} should be {known_hash}
-Please check your local file for tampering or corruption.""")
+        raise IntegrityError(name, known_hash, hash)


 def onehot(ind):
@@ -113,9 +114,13 @@ def read_array(f):
     return np.array(array.array("B", f.read()), dtype=np.uint8)


-def open_maybe_gzip(path):
-    opener = gzip.open if path.endswith(".gz") else open
-    return opener(path, "rb")
+def open_maybe_gzip(path, flags="rb"):
+    if path.endswith(".gz"):
+        logger.debug("Opening %s with gzip.open", path)
+        return gzip.open(path, flags)
+    else:
+        logger.debug("Opening %s with builtin open", path)
+        return open(path, flags)


 def load(images_path, labels_path):
@ -150,7 +155,7 @@ def batch_flatten(x):
def prepare(dataset="mnist", return_floats=True, return_onehot=True, def prepare(dataset="mnist", return_floats=True, return_onehot=True,
flatten=False, squeeze_labels=True, check_integrity=True): flatten=False, squeeze_labels=True, check_integrity=True):
if dataset not in metadata.keys(): if dataset not in metadata.keys():
raise Exception(f"Unknown dataset: {dataset}") raise UnknownDatasetError(dataset)
meta = metadata[dataset] meta = metadata[dataset]
prefix, names = meta[0], meta[1:] prefix, names = meta[0], meta[1:]
@@ -158,6 +163,9 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
     npz, train_images, train_labels, test_images, test_labels = names
     images_and_labels = names[1:]

+    logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
+                 dataset, train_images, train_labels, test_images, test_labels)
+
     make_directories()

     existing = [os.path.isfile(construct_path(name)) for name in names]
@@ -169,6 +177,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
             validate(name)

     if npz_existing:
+        logger.info("Loading npz file %s", npz)
         with np.load(construct_path(npz)):
             train_images_data, train_labels_data, \
                 test_images_data, test_labels_data = \
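The diff above swaps the ad-hoc lament() helper for a module-level logger created with logging.getLogger(__name__), so visibility of these messages is now decided by whoever imports the package. A minimal sketch of that, assuming the file shown above is the package's __init__.py (so the logger name resolves to "mnists") and that prepare() hands back the four arrays:

    import logging

    import mnists

    # Assumption: the application opts in to the library's INFO/DEBUG records;
    # with no configuration at all, only warnings and errors would surface.
    logging.basicConfig(level=logging.INFO)

    # Triggers the "Creating directory ...", "Loading npz file ..." messages above.
    train_x, train_y, test_x, test_y = mnists.prepare("mnist")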

View File

@ -1,5 +1,10 @@
from . import metadata, prepare import logging
from . import metadata, prepare, logger
logging.basicConfig()
logger.setLevel(logging.DEBUG)
headers = ("subdirectory", headers = ("subdirectory",
"dataset", "dataset",

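The script above calls logging.basicConfig() and forces the package logger down to DEBUG, which suits development. An application embedding the library can make the opposite choice from the outside; a small sketch, assuming the package logger is the one created by logging.getLogger(__name__) in the package's __init__.py and therefore named "mnists":

    import logging

    # Assumption: the library's logger is named "mnists". Raising its threshold
    # hides the DEBUG/INFO chatter but keeps download-failure warnings visible.
    logging.getLogger("mnists").setLevel(logging.WARNING)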
mnists/exceptions.py (new file, 18 lines)
View File

@@ -0,0 +1,18 @@
+class IntegrityError(Exception):
+    def __init__(self, file, expected_hash, computed_hash):
+        self.file = file
+        self.expected_hash = expected_hash
+        self.computed_hash = computed_hash
+
+    def __str__(self):
+        return f"""Failed to validate dataset: {self.file}
+Hash mismatch: {self.computed_hash} should be {self.expected_hash}
+Please check your local file for tampering or corruption."""
+
+
+class UnknownDatasetError(Exception):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __str__(self):
+        return f"Unknown mnist-like dataset: {self.dataset}"
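With the wildcard import (from .exceptions import *) in place, the new classes are reachable both from the package itself and from mnists.exceptions, so callers can react to the specific failure instead of a bare Exception. A minimal sketch, assuming prepare() is the entry point and using an intentionally bogus dataset name:

    import mnists
    from mnists.exceptions import IntegrityError, UnknownDatasetError

    try:
        mnists.prepare("not-a-dataset")  # hypothetical name; not in metadata
    except UnknownDatasetError as e:
        print("no such dataset:", e.dataset)
    except IntegrityError as e:
        # Raised by validate() when a downloaded file's SHA-256 doesn't match.
        print("hash mismatch for", e.file,
              "expected", e.expected_hash, "got", e.computed_hash)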