fix typo: squeeze_layers -> squeeze_labels
This commit is contained in:
parent
d6e4a9bd3f
commit
cbc9c4d7e4
1 changed files with 2 additions and 2 deletions
|
@ -148,7 +148,7 @@ def batch_flatten(x):
|
|||
|
||||
|
||||
def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
||||
flatten=False, squeeze_layers=True, check_integrity=True):
|
||||
flatten=False, squeeze_labels=True, check_integrity=True):
|
||||
if dataset not in metadata.keys():
|
||||
raise Exception(f"Unknown dataset: {dataset}")
|
||||
|
||||
|
@ -191,7 +191,7 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True,
|
|||
|
||||
# emnist_letters uses labels indexed from 1 instead of the usual 0.
|
||||
# the onehot function assumes labels are contiguous from 0.
|
||||
if squeeze_layers or return_onehot:
|
||||
if squeeze_labels or return_onehot:
|
||||
train_labels_data = squeeze(train_labels_data)
|
||||
test_labels_data = squeeze(test_labels_data)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue