diff --git a/mnists/__init__.py b/mnists/__init__.py index 3b504e6..cb4efe9 100644 --- a/mnists/__init__.py +++ b/mnists/__init__.py @@ -146,7 +146,7 @@ def batch_flatten(x): def prepare(dataset="mnist", return_floats=True, return_onehot=True, - flatten=False, squeeze_layers=True): + flatten=False, squeeze_layers=True, check_integrity=True): if dataset not in metadata.keys(): raise Exception(f"Unknown dataset: {dataset}") @@ -163,7 +163,8 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True, for name in images_and_labels: download(name) - validate(name) + if check_integrity: + validate(name) if npz_existing: with np.load(construct_path(npz)):