fix typo: squeeze_layers -> squeeze_labels
This commit is contained in:
parent
d6e4a9bd3f
commit
cbc9c4d7e4
|
@ -148,7 +148,7 @@ def batch_flatten(x):
|
||||||
|
|
||||||
|
|
||||||
def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
||||||
flatten=False, squeeze_layers=True, check_integrity=True):
|
flatten=False, squeeze_labels=True, check_integrity=True):
|
||||||
if dataset not in metadata.keys():
|
if dataset not in metadata.keys():
|
||||||
raise Exception(f"Unknown dataset: {dataset}")
|
raise Exception(f"Unknown dataset: {dataset}")
|
||||||
|
|
||||||
|
@ -191,7 +191,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
||||||
|
|
||||||
# emnist_letters uses labels indexed from 1 instead of the usual 0.
|
# emnist_letters uses labels indexed from 1 instead of the usual 0.
|
||||||
# the onehot function assumes labels are contiguous from 0.
|
# the onehot function assumes labels are contiguous from 0.
|
||||||
if squeeze_layers or return_onehot:
|
if squeeze_labels or return_onehot:
|
||||||
train_labels_data = squeeze(train_labels_data)
|
train_labels_data = squeeze(train_labels_data)
|
||||||
test_labels_data = squeeze(test_labels_data)
|
test_labels_data = squeeze(test_labels_data)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user