add Exception subclasses and basic logging
This commit is contained in:
parent
cbc9c4d7e4
commit
6fe25b1157
3 changed files with 47 additions and 15 deletions
|
@ -10,6 +10,7 @@ __version__ = "0.2.1"
|
|||
import array
|
||||
import gzip
|
||||
import hashlib
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path
|
||||
|
@ -18,17 +19,16 @@ import sys
|
|||
from urllib.request import urlretrieve
|
||||
|
||||
from .hashes import hashes
|
||||
from .exceptions import *
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
output_directory = os.path.join(home, ".mnist")
|
||||
webhost = "https://eaguru.guru/mnist/"
|
||||
|
||||
|
||||
def lament(*args, **kwargs):
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
|
||||
def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
|
||||
images_suffix = "-images-idx3-ubyte.gz"
|
||||
labels_suffix = "-labels-idx1-ubyte.gz"
|
||||
|
@ -68,6 +68,7 @@ def make_directories():
|
|||
directories += [construct_path(meta[0]) for meta in metadata.values()]
|
||||
for directory in directories:
|
||||
if not os.path.isdir(directory):
|
||||
logger.info("Creating directory %s", directory)
|
||||
os.mkdir(directory)
|
||||
|
||||
|
||||
|
@ -75,19 +76,21 @@ def download(name):
|
|||
url_name = "/".join(os.path.split(name))
|
||||
path = construct_path(name)
|
||||
already_exists = os.path.isfile(path)
|
||||
if not already_exists:
|
||||
if already_exists:
|
||||
logger.debug("Not downloading %s, it already exists: %s", name, path)
|
||||
else:
|
||||
url = webhost + url_name
|
||||
try:
|
||||
urlretrieve(url, path)
|
||||
except Exception:
|
||||
lament(f"Failed to download {url} to {path}")
|
||||
logger.warning("Failed to download %s to %s", url, path)
|
||||
raise
|
||||
return already_exists
|
||||
|
||||
|
||||
def validate(name):
|
||||
if name not in hashes.keys():
|
||||
raise Exception(f"Unknown mnist dataset: {name}")
|
||||
raise UnknownDatasetError(name)
|
||||
|
||||
with open(construct_path(name), "rb") as f:
|
||||
data = f.read()
|
||||
|
@ -96,9 +99,7 @@ def validate(name):
|
|||
hash = hashlib.sha256(data).hexdigest()
|
||||
|
||||
if hash != known_hash:
|
||||
raise Exception(f"""Failed to validate dataset: {name}
|
||||
Hash mismatch: {hash} should be {known_hash}
|
||||
Please check your local file for tampering or corruption.""")
|
||||
raise IntegrityError(file, known_hash, hash)
|
||||
|
||||
|
||||
def onehot(ind):
|
||||
|
@ -113,9 +114,13 @@ def read_array(f):
|
|||
return np.array(array.array("B", f.read()), dtype=np.uint8)
|
||||
|
||||
|
||||
def open_maybe_gzip(path):
|
||||
opener = gzip.open if path.endswith(".gz") else open
|
||||
return opener(path, "rb")
|
||||
def open_maybe_gzip(path, flags="rb"):
|
||||
if path.endswith(".gz"):
|
||||
logger.debug("Opening %s with gzip.open", path)
|
||||
return gzip.open(path, flags)
|
||||
else:
|
||||
logger.debug("Opening %s with builtin open", path)
|
||||
return open(path, flags)
|
||||
|
||||
|
||||
def load(images_path, labels_path):
|
||||
|
@ -150,7 +155,7 @@ def batch_flatten(x):
|
|||
def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
||||
flatten=False, squeeze_labels=True, check_integrity=True):
|
||||
if dataset not in metadata.keys():
|
||||
raise Exception(f"Unknown dataset: {dataset}")
|
||||
raise UnknownDatasetError(dataset)
|
||||
|
||||
meta = metadata[dataset]
|
||||
prefix, names = meta[0], meta[1:]
|
||||
|
@ -158,6 +163,9 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
|||
npz, train_images, train_labels, test_images, test_labels = names
|
||||
images_and_labels = names[1:]
|
||||
|
||||
logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
|
||||
dataset, train_images, train_labels, test_images, test_labels)
|
||||
|
||||
make_directories()
|
||||
|
||||
existing = [os.path.isfile(construct_path(name)) for name in names]
|
||||
|
@ -169,6 +177,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
|||
validate(name)
|
||||
|
||||
if npz_existing:
|
||||
logger.info("Loading npz file %s", npz)
|
||||
with np.load(construct_path(npz)):
|
||||
train_images_data, train_labels_data, \
|
||||
test_images_data, test_labels_data = \
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
from . import metadata, prepare
|
||||
import logging
|
||||
|
||||
from . import metadata, prepare, logger
|
||||
|
||||
|
||||
logging.basicConfig()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
headers = ("subdirectory",
|
||||
"dataset",
|
||||
|
|
18
mnists/exceptions.py
Normal file
18
mnists/exceptions.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
class IntegrityError(Exception):
|
||||
def __init__(self, file, expected_hash, computed_hash):
|
||||
self.file = file
|
||||
self.expected_hash = expected_hash
|
||||
self.computed_hash = computed_hash
|
||||
|
||||
def __str__(self):
|
||||
return f"""Failed to validate dataset: {name}
|
||||
Hash mismatch: {self.computed_hash} should be {self.expected_hash}
|
||||
Please check your local file for tampering or corruption."""
|
||||
|
||||
|
||||
class UnknownDatasetError(Exception):
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __str__(self):
|
||||
return f"Unknown mnist-like dataset: {dataset}"
|
Loading…
Add table
Reference in a new issue