diff --git a/mnists/__init__.py b/mnists/__init__.py index f43bc3e..9920bd3 100644 --- a/mnists/__init__.py +++ b/mnists/__init__.py @@ -218,10 +218,8 @@ def prepare(dataset="mnist", return_floats=True, return_onehot=True, test_images_data = test_images_data.transpose(0,1,3,2) if return_floats: # TODO: better name. - train_images_data = train_images_data.astype(np.float32) / np.float32(255) - test_images_data = test_images_data.astype(np.float32) / np.float32(255) - assert train_images_data.dtype == 'float32' - assert test_images_data.dtype == 'float32' + train_images_data = train_images_data.astype(np.float32) / 255 + test_images_data = test_images_data.astype(np.float32) / 255 # emnist_letters uses labels indexed from 1 instead of the usual 0. # the onehot function assumes labels are contiguous from 0.