在用tensorflow学习简单的神经网络时,报如下错误:
DeprecationWarning: elementwise == comparison failed; this will raise an error
原因是:在处理数据时,我运行了如下代码后,又重新粘贴了一份,再次运行,使得数据变成了三维,造成了数据的不匹配:
代码:
image_size = 28
num_labels = 10
def reformat(dataset, labels):
dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
# Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)
正确结果:
Training set (200000, 784) (200000, 10)
Validation set (10000, 784) (10000, 10)
Test set (10000, 784) (10000, 10)
运行错误结果:
Training set (200000, 784) (200000, 1, 10)
Validation set (10000, 784) (10000, 1, 10)
Test set (10000, 784) (10000, 1, 10)
很明显,数据的维度有错
解决方法:重新运行第一遍的程序即可,只允许一遍即可。