add Exception subclasses and basic logging

This commit is contained in:
Connor Olding 2018-03-24 11:04:21 +01:00
parent cbc9c4d7e4
commit 6fe25b1157
3 changed files with 47 additions and 15 deletions

View File

@ -10,6 +10,7 @@ __version__ = "0.2.1"
import array
import gzip
import hashlib
import logging
import numpy as np
import os
import os.path
@ -18,17 +19,16 @@ import sys
from urllib.request import urlretrieve
from .hashes import hashes
from .exceptions import *
logger = logging.getLogger(__name__)
home = os.path.expanduser("~")
output_directory = os.path.join(home, ".mnist")
webhost = "https://eaguru.guru/mnist/"
def lament(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
images_suffix = "-images-idx3-ubyte.gz"
labels_suffix = "-labels-idx1-ubyte.gz"
@ -68,6 +68,7 @@ def make_directories():
directories += [construct_path(meta[0]) for meta in metadata.values()]
for directory in directories:
if not os.path.isdir(directory):
logger.info("Creating directory %s", directory)
os.mkdir(directory)
@ -75,19 +76,21 @@ def download(name):
url_name = "/".join(os.path.split(name))
path = construct_path(name)
already_exists = os.path.isfile(path)
if not already_exists:
if already_exists:
logger.debug("Not downloading %s, it already exists: %s", name, path)
else:
url = webhost + url_name
try:
urlretrieve(url, path)
except Exception:
lament(f"Failed to download {url} to {path}")
logger.warning("Failed to download %s to %s", url, path)
raise
return already_exists
def validate(name):
if name not in hashes.keys():
raise Exception(f"Unknown mnist dataset: {name}")
raise UnknownDatasetError(name)
with open(construct_path(name), "rb") as f:
data = f.read()
@ -96,9 +99,7 @@ def validate(name):
hash = hashlib.sha256(data).hexdigest()
if hash != known_hash:
raise Exception(f"""Failed to validate dataset: {name}
Hash mismatch: {hash} should be {known_hash}
Please check your local file for tampering or corruption.""")
raise IntegrityError(file, known_hash, hash)
def onehot(ind):
@ -113,9 +114,13 @@ def read_array(f):
return np.array(array.array("B", f.read()), dtype=np.uint8)
def open_maybe_gzip(path):
opener = gzip.open if path.endswith(".gz") else open
return opener(path, "rb")
def open_maybe_gzip(path, flags="rb"):
if path.endswith(".gz"):
logger.debug("Opening %s with gzip.open", path)
return gzip.open(path, flags)
else:
logger.debug("Opening %s with builtin open", path)
return open(path, flags)
def load(images_path, labels_path):
@ -150,7 +155,7 @@ def batch_flatten(x):
def prepare(dataset="mnist", return_floats=True, return_onehot=True,
flatten=False, squeeze_labels=True, check_integrity=True):
if dataset not in metadata.keys():
raise Exception(f"Unknown dataset: {dataset}")
raise UnknownDatasetError(dataset)
meta = metadata[dataset]
prefix, names = meta[0], meta[1:]
@ -158,6 +163,9 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
npz, train_images, train_labels, test_images, test_labels = names
images_and_labels = names[1:]
logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
dataset, train_images, train_labels, test_images, test_labels)
make_directories()
existing = [os.path.isfile(construct_path(name)) for name in names]
@ -169,6 +177,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
validate(name)
if npz_existing:
logger.info("Loading npz file %s", npz)
with np.load(construct_path(npz)):
train_images_data, train_labels_data, \
test_images_data, test_labels_data = \

View File

@ -1,5 +1,10 @@
from . import metadata, prepare
import logging
from . import metadata, prepare, logger
logging.basicConfig()
logger.setLevel(logging.DEBUG)
headers = ("subdirectory",
"dataset",

18
mnists/exceptions.py Normal file
View File

@ -0,0 +1,18 @@
class IntegrityError(Exception):
def __init__(self, file, expected_hash, computed_hash):
self.file = file
self.expected_hash = expected_hash
self.computed_hash = computed_hash
def __str__(self):
return f"""Failed to validate dataset: {name}
Hash mismatch: {self.computed_hash} should be {self.expected_hash}
Please check your local file for tampering or corruption."""
class UnknownDatasetError(Exception):
def __init__(self, dataset):
self.dataset = dataset
def __str__(self):
return f"Unknown mnist-like dataset: {dataset}"