add Exception subclasses and basic logging
parent cbc9c4d7e4
commit 6fe25b1157
@@ -10,6 +10,7 @@ __version__ = "0.2.1"
 import array
 import gzip
 import hashlib
+import logging
 import numpy as np
 import os
 import os.path
@@ -18,17 +19,16 @@ import sys
 from urllib.request import urlretrieve
 
 from .hashes import hashes
+from .exceptions import *
 
 
+logger = logging.getLogger(__name__)
+
 home = os.path.expanduser("~")
 output_directory = os.path.join(home, ".mnist")
 webhost = "https://eaguru.guru/mnist/"
 
 
-def lament(*args, **kwargs):
-    print(*args, file=sys.stderr, **kwargs)
-
-
 def make_meta(npz, train_part="train", test_part="t10k", prefix=""):
     images_suffix = "-images-idx3-ubyte.gz"
     labels_suffix = "-labels-idx1-ubyte.gz"
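The module-level `logger = logging.getLogger(__name__)` follows the usual library pattern: the package emits records but never installs handlers, so nothing is printed until the caller opts in (which is exactly what the `logging.basicConfig()` / `setLevel` lines in the second file of this commit do). A minimal sketch of the same setup from a consuming script, assuming the package is importable as `mnists`:

    import logging

    import mnists  # assumes the package is installed/importable as "mnists"

    # Install a root handler first; without one the library's records go nowhere.
    # basicConfig() uses the default "LEVEL:logger.name:message" format.
    logging.basicConfig()

    # Opt in to the DEBUG-level messages added in this commit
    # ("Not downloading ...", "Opening ... with gzip.open", ...).
    mnists.logger.setLevel(logging.DEBUG)

    mnists.prepare()  # "mnist" is the default dataset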
@@ -68,6 +68,7 @@ def make_directories():
     directories += [construct_path(meta[0]) for meta in metadata.values()]
     for directory in directories:
         if not os.path.isdir(directory):
+            logger.info("Creating directory %s", directory)
             os.mkdir(directory)
 
 
@@ -75,19 +76,21 @@ def download(name):
     url_name = "/".join(os.path.split(name))
     path = construct_path(name)
     already_exists = os.path.isfile(path)
-    if not already_exists:
+    if already_exists:
+        logger.debug("Not downloading %s, it already exists: %s", name, path)
+    else:
         url = webhost + url_name
         try:
             urlretrieve(url, path)
         except Exception:
-            lament(f"Failed to download {url} to {path}")
+            logger.warning("Failed to download %s to %s", url, path)
             raise
     return already_exists
 
 
 def validate(name):
     if name not in hashes.keys():
-        raise Exception(f"Unknown mnist dataset: {name}")
+        raise UnknownDatasetError(name)
 
     with open(construct_path(name), "rb") as f:
         data = f.read()
@@ -96,9 +99,7 @@ def validate(name):
     hash = hashlib.sha256(data).hexdigest()
 
     if hash != known_hash:
-        raise Exception(f"""Failed to validate dataset: {name}
-Hash mismatch: {hash} should be {known_hash}
-Please check your local file for tampering or corruption.""")
+        raise IntegrityError(name, known_hash, hash)
 
 
 def onehot(ind):
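Because `validate` and `prepare` now raise dedicated subclasses instead of bare `Exception`, calling code can tell a mistyped dataset name apart from a corrupted download without parsing message strings. A rough sketch of that pattern, assuming the package is importable as `mnists` (the handling shown is illustrative, not part of the commit):

    from mnists import prepare
    from mnists.exceptions import IntegrityError, UnknownDatasetError

    try:
        prepare("mnist", check_integrity=True)
    except UnknownDatasetError as err:
        # Raised up front when the name is not one of the known datasets.
        print("no such dataset:", err.dataset)
    except IntegrityError as err:
        # Raised by validate() when a file's sha256 does not match the recorded hash.
        print("corrupt file:", err.file)
        print("expected", err.expected_hash, "got", err.computed_hash)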
@@ -113,9 +114,13 @@ def read_array(f):
     return np.array(array.array("B", f.read()), dtype=np.uint8)
 
 
-def open_maybe_gzip(path):
-    opener = gzip.open if path.endswith(".gz") else open
-    return opener(path, "rb")
+def open_maybe_gzip(path, flags="rb"):
+    if path.endswith(".gz"):
+        logger.debug("Opening %s with gzip.open", path)
+        return gzip.open(path, flags)
+    else:
+        logger.debug("Opening %s with builtin open", path)
+        return open(path, flags)
 
 
 def load(images_path, labels_path):
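`open_maybe_gzip` now takes the mode through a `flags` parameter instead of hard-coding `"rb"`, and logs which opener it dispatched to. Roughly how the helper is meant to be used (the import path and the file paths below are assumptions for illustration):

    from mnists import open_maybe_gzip  # internal helper; import path is an assumption

    # A gzipped IDX file is routed to gzip.open and read back decompressed.
    with open_maybe_gzip("t10k-labels-idx1-ubyte.gz") as f:
        raw = f.read()

    # Any other path falls through to the builtin open, with the same flags.
    with open_maybe_gzip("mnist.npz", flags="rb") as f:
        head = f.read(16)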
@@ -150,7 +155,7 @@ def batch_flatten(x):
 def prepare(dataset="mnist", return_floats=True, return_onehot=True,
             flatten=False, squeeze_labels=True, check_integrity=True):
     if dataset not in metadata.keys():
-        raise Exception(f"Unknown dataset: {dataset}")
+        raise UnknownDatasetError(dataset)
 
     meta = metadata[dataset]
     prefix, names = meta[0], meta[1:]
@@ -158,6 +163,9 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
     npz, train_images, train_labels, test_images, test_labels = names
     images_and_labels = names[1:]
 
+    logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
+                 dataset, train_images, train_labels, test_images, test_labels)
+
     make_directories()
 
     existing = [os.path.isfile(construct_path(name)) for name in names]
@@ -169,6 +177,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
             validate(name)
 
     if npz_existing:
+        logger.info("Loading npz file %s", npz)
         with np.load(construct_path(npz)):
             train_images_data, train_labels_data, \
                 test_images_data, test_labels_data = \
@@ -1,5 +1,10 @@
-from . import metadata, prepare
+import logging
+
+from . import metadata, prepare, logger
 
 
+logging.basicConfig()
+logger.setLevel(logging.DEBUG)
+
 headers = ("subdirectory",
            "dataset",
mnists/exceptions.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+class IntegrityError(Exception):
+    def __init__(self, file, expected_hash, computed_hash):
+        self.file = file
+        self.expected_hash = expected_hash
+        self.computed_hash = computed_hash
+
+    def __str__(self):
+        return f"""Failed to validate dataset: {self.file}
+Hash mismatch: {self.computed_hash} should be {self.expected_hash}
+Please check your local file for tampering or corruption."""
+
+
+class UnknownDatasetError(Exception):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __str__(self):
+        return f"Unknown mnist-like dataset: {self.dataset}"
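Both subclasses keep the offending values as attributes and leave message formatting to `__str__`, so the text above is exactly what shows up in a traceback or a log record. For illustration only (the file name and hashes below are invented):

    from mnists.exceptions import IntegrityError, UnknownDatasetError

    err = IntegrityError("train-images-idx3-ubyte.gz", "aaaa1111", "bbbb2222")
    print(err)
    # Failed to validate dataset: train-images-idx3-ubyte.gz
    # Hash mismatch: bbbb2222 should be aaaa1111
    # Please check your local file for tampering or corruption.

    print(UnknownDatasetError("not-a-dataset"))
    # Unknown mnist-like dataset: not-a-dataset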