remove incomplete npz functionality
This commit is contained in:
parent
c7da7e5fbf
commit
6efccb6b20
2 changed files with 20 additions and 30 deletions
4
TODO
4
TODO
|
@ -2,8 +2,6 @@ TODO
|
||||||
|
|
||||||
* finish writing README
|
* finish writing README
|
||||||
|
|
||||||
* finish npz functionality
|
|
||||||
|
|
||||||
* document prepare() function
|
* document prepare() function
|
||||||
|
|
||||||
* support python 3.5
|
* support python 3.5
|
||||||
|
@ -17,3 +15,5 @@ TODO
|
||||||
* document everything
|
* document everything
|
||||||
|
|
||||||
* test everything
|
* test everything
|
||||||
|
|
||||||
|
* consider npz functionality if it reduces preprocessing time
|
||||||
|
|
|
@ -5,8 +5,6 @@
|
||||||
|
|
||||||
__version__ = "0.2.1"
|
__version__ = "0.2.1"
|
||||||
|
|
||||||
# NOTE: npz functionality is incomplete.
|
|
||||||
|
|
||||||
import array
|
import array
|
||||||
import gzip
|
import gzip
|
||||||
import hashlib
|
import hashlib
|
||||||
|
@ -29,29 +27,29 @@ output_directory = os.path.join(home, ".mnist")
|
||||||
webhost = "https://eaguru.guru/mnist/"
|
webhost = "https://eaguru.guru/mnist/"
|
||||||
|
|
||||||
|
|
||||||
def _make_meta(npz, train_part="train", test_part="t10k", prefix=""):
|
def _make_meta(train_part="train", test_part="t10k", prefix=""):
|
||||||
images_suffix = "-images-idx3-ubyte.gz"
|
images_suffix = "-images-idx3-ubyte.gz"
|
||||||
labels_suffix = "-labels-idx1-ubyte.gz"
|
labels_suffix = "-labels-idx1-ubyte.gz"
|
||||||
return (prefix, npz,
|
return (prefix,
|
||||||
train_part + images_suffix,
|
train_part + images_suffix,
|
||||||
train_part + labels_suffix,
|
train_part + labels_suffix,
|
||||||
test_part + images_suffix,
|
test_part + images_suffix,
|
||||||
test_part + labels_suffix)
|
test_part + labels_suffix)
|
||||||
|
|
||||||
|
|
||||||
def _emnist_meta(npz, name):
|
def _emnist_meta(name):
|
||||||
return _make_meta(npz, name + "-train", name + "-test", prefix="emnist")
|
return _make_meta(name + "-train", name + "-test", prefix="emnist")
|
||||||
|
|
||||||
|
|
||||||
metadata = dict(
|
metadata = dict(
|
||||||
emnist_balanced=_emnist_meta("emnist_balanced.npz", "emnist-balanced"),
|
emnist_balanced=_emnist_meta("emnist-balanced"),
|
||||||
emnist_byclass=_emnist_meta("emnist_byclass.npz", "emnist-byclass"),
|
emnist_byclass=_emnist_meta("emnist-byclass"),
|
||||||
emnist_bymerge=_emnist_meta("emnist_bymerge.npz", "emnist-bymerge"),
|
emnist_bymerge=_emnist_meta("emnist-bymerge"),
|
||||||
emnist_digits=_emnist_meta("emnist_digits.npz", "emnist-digits"),
|
emnist_digits=_emnist_meta("emnist-digits"),
|
||||||
emnist_letters=_emnist_meta("emnist_letters.npz", "emnist-letters"),
|
emnist_letters=_emnist_meta("emnist-letters"),
|
||||||
emnist_mnist=_emnist_meta("emnist_mnist.npz", "emnist-mnist"),
|
emnist_mnist=_emnist_meta("emnist-mnist"),
|
||||||
fashion_mnist=_make_meta("fashion_mnist.npz", prefix="fashion-mnist"),
|
fashion_mnist=_make_meta(prefix="fashion-mnist"),
|
||||||
mnist=_make_meta("mnist.npz", prefix="mnist"),
|
mnist=_make_meta(prefix="mnist"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,7 +158,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
||||||
meta = metadata[dataset]
|
meta = metadata[dataset]
|
||||||
prefix, names = meta[0], meta[1:]
|
prefix, names = meta[0], meta[1:]
|
||||||
names = [os.path.join(prefix, name) for name in names]
|
names = [os.path.join(prefix, name) for name in names]
|
||||||
npz, train_images, train_labels, test_images, test_labels = names
|
train_images, train_labels, test_images, test_labels = names
|
||||||
images_and_labels = names[1:]
|
images_and_labels = names[1:]
|
||||||
|
|
||||||
logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
|
logger.debug("Filenames chosen for %s: %s, %s, %s, %s",
|
||||||
|
@ -169,21 +167,13 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
||||||
make_directories()
|
make_directories()
|
||||||
|
|
||||||
existing = [os.path.isfile(construct_path(name)) for name in names]
|
existing = [os.path.isfile(construct_path(name)) for name in names]
|
||||||
npz_existing, gz_existing = existing[0], existing[1:]
|
gz_existing = existing[0], existing[1:]
|
||||||
|
|
||||||
for name in images_and_labels:
|
for name in images_and_labels:
|
||||||
download(name)
|
download(name)
|
||||||
if check_integrity:
|
if check_integrity:
|
||||||
validate(name)
|
validate(name)
|
||||||
|
|
||||||
if npz_existing:
|
|
||||||
logger.info("Loading npz file %s", npz)
|
|
||||||
with np.load(construct_path(npz)):
|
|
||||||
train_images_data, train_labels_data, \
|
|
||||||
test_images_data, test_labels_data = \
|
|
||||||
f["train_images"], f["train_labels"], \
|
|
||||||
f["test_images"], f["test_labels"]
|
|
||||||
else:
|
|
||||||
train_images_data, train_labels_data = load(
|
train_images_data, train_labels_data = load(
|
||||||
construct_path(train_images), construct_path(train_labels))
|
construct_path(train_images), construct_path(train_labels))
|
||||||
test_images_data, test_labels_data = load(
|
test_images_data, test_labels_data = load(
|
||||||
|
|
Loading…
Add table
Reference in a new issue