From 6fe25b1157e179897f5ee4ca45a00566faec093a Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Sat, 24 Mar 2018 11:04:21 +0100
Subject: [PATCH] add Exception subclasses and basic logging

---
 mnists/__init__.py   | 37 +++++++++++++++++++++++--------------
 mnists/__main__.py   |  7 ++++++-
 mnists/exceptions.py | 18 ++++++++++++++++++
 3 files changed, 47 insertions(+), 15 deletions(-)
 create mode 100644 mnists/exceptions.py

diff --git a/mnists/__init__.py b/mnists/__init__.py
index 67ea989..091dffc 100644
--- a/mnists/__init__.py
+++ b/mnists/__init__.py
@@ -10,6 +10,7 @@ __version__ = "0.2.1"
 import array
 import gzip
 import hashlib
+import logging
 import numpy as np
 import os
 import os.path
@@ -18,17 +19,16 @@ import sys
 from urllib.request import urlretrieve
 
 from .hashes import hashes
+from .exceptions import *
 
+logger = logging.getLogger(__name__)
+
 
 home = os.path.expanduser("~")
 output_directory = os.path.join(home, ".mnist")
 webhost = "https://eaguru.guru/mnist/"
 
 
-def lament(*args, **kwargs):
-    print(*args, file=sys.stderr, **kwargs)
-
-
 def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
     images_suffix = "-images-idx3-ubyte.gz"
     labels_suffix = "-labels-idx1-ubyte.gz"
@@ -68,6 +68,7 @@ def make_directories():
     directories += [construct_path(meta[0]) for meta in metadata.values()]
     for directory in directories:
         if not os.path.isdir(directory):
+            logger.info("Creating directory %s", directory)
             os.mkdir(directory)
 
 
@@ -75,19 +76,21 @@ def download(name):
     url_name = "/".join(os.path.split(name))
     path = construct_path(name)
    already_exists = os.path.isfile(path)
-    if not already_exists:
+    if already_exists:
+        logger.debug("Not downloading %s, it already exists: %s", name, path)
+    else:
         url = webhost + url_name
         try:
             urlretrieve(url, path)
         except Exception:
-            lament(f"Failed to download {url} to {path}")
+            logger.warning("Failed to download %s to %s", url, path)
             raise
     return already_exists
 
 
 def validate(name):
     if name not in hashes.keys():
-        raise Exception(f"Unknown mnist dataset: {name}")
+        raise UnknownDatasetError(name)
 
     with open(construct_path(name), "rb") as f:
         data = f.read()
@@ -96,9 +99,7 @@ def validate(name):
 
     hash = hashlib.sha256(data).hexdigest()
     if hash != known_hash:
-        raise Exception(f"""Failed to validate dataset: {name}
-Hash mismatch: {hash} should be {known_hash}
-Please check your local file for tampering or corruption.""")
+        raise IntegrityError(name, known_hash, hash)
 
 
 def onehot(ind):
@@ -113,9 +114,13 @@ def read_array(f):
     return np.array(array.array("B", f.read()), dtype=np.uint8)
 
 
-def open_maybe_gzip(path):
-    opener = gzip.open if path.endswith(".gz") else open
-    return opener(path, "rb")
+def open_maybe_gzip(path, flags="rb"):
+    if path.endswith(".gz"):
+        logger.debug("Opening %s with gzip.open", path)
+        return gzip.open(path, flags)
+    else:
+        logger.debug("Opening %s with builtin open", path)
+        return open(path, flags)
 
 
 def load(images_path, labels_path):
@@ -150,7 +155,7 @@ def batch_flatten(x):
 def prepare(dataset="mnist", return_floats=True, return_onehot=True,
             flatten=False, squeeze_labels=True, check_integrity=True):
     if dataset not in metadata.keys():
-        raise Exception(f"Unknown dataset: {dataset}")
+        raise UnknownDatasetError(dataset)
 
     meta = metadata[dataset]
     prefix, names = meta[0], meta[1:]
@@ -158,6 +163,9 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
     npz, train_images, train_labels, test_images, test_labels = names
     images_and_labels = names[1:]
 
+    logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
+                 dataset, train_images, train_labels, test_images, test_labels)
+
     make_directories()
 
     existing = [os.path.isfile(construct_path(name)) for name in names]
@@ -169,6 +177,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
             validate(name)
 
     if npz_existing:
+        logger.info("Loading npz file %s", npz)
         with np.load(construct_path(npz)):
             train_images_data, train_labels_data, \
               test_images_data, test_labels_data = \
diff --git a/mnists/__main__.py b/mnists/__main__.py
index 4e13aee..a504d4a 100644
--- a/mnists/__main__.py
+++ b/mnists/__main__.py
@@ -1,5 +1,10 @@
-from . import metadata, prepare
+import logging
+from . import metadata, prepare, logger
+
+
+logging.basicConfig()
+logger.setLevel(logging.DEBUG)
 
 
 headers = ("subdirectory",
            "dataset",
diff --git a/mnists/exceptions.py b/mnists/exceptions.py
new file mode 100644
index 0000000..c5b2b57
--- /dev/null
+++ b/mnists/exceptions.py
@@ -0,0 +1,18 @@
+class IntegrityError(Exception):
+    def __init__(self, file, expected_hash, computed_hash):
+        self.file = file
+        self.expected_hash = expected_hash
+        self.computed_hash = computed_hash
+
+    def __str__(self):
+        return f"""Failed to validate dataset: {self.file}
+Hash mismatch: {self.computed_hash} should be {self.expected_hash}
+Please check your local file for tampering or corruption."""
+
+
+class UnknownDatasetError(Exception):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __str__(self):
+        return f"Unknown mnist-like dataset: {self.dataset}"
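
Usage note (not part of the patch itself): a minimal sketch of how a script built
on top of the patched package might surface the new log messages and catch the new
exception classes. Only prepare(), logger, UnknownDatasetError, and IntegrityError
come from the patch; the logging level and the error handling shown here are
illustrative assumptions.

    import logging

    import mnists
    from mnists.exceptions import IntegrityError, UnknownDatasetError

    # Route the package's log records somewhere visible; INFO is an arbitrary
    # choice for this sketch (the patched __main__.py uses DEBUG instead).
    logging.basicConfig(level=logging.INFO)

    try:
        # prepare() downloads, validates, and loads the requested dataset.
        data = mnists.prepare("mnist")
    except UnknownDatasetError as e:
        print("no such dataset:", e.dataset)
    except IntegrityError as e:
        # e.file, e.expected_hash, and e.computed_hash carry the details;
        # str(e) prints the full mismatch message.
        print(e)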