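Some of the code in the book no longer runs. To save readers from hitting the same problems, the corrections are collected here.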
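In older versions of scikit-learn, the MNIST dataset was loaded with (fetch_mldata was removed in scikit-learn 0.22):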
>>> from sklearn.datasets import fetch_mldata
>>> mnist = fetch_mldata('MNIST original')
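In newer versions, load it with fetch_openml instead, then extract X and y and split off the training set, since the code below relies on them: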
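from sklearn.datasets import fetch_openml
mnist_data = fetch_openml("mnist_784", version=1)
X, y = mnist_data["data"], mnist_data["target"]  # X: 70000 x 784 pixels, y: labels as strings
X_train, X_test, y_train, y_test = X.iloc[:60000], X.iloc[60000:], y.iloc[:60000], y.iloc[60000:]  # first 60000 = training set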
In older versions, one of the digits was displayed with:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
some_digit = X[36000]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary, interpolation="nearest")
plt.axis("off")
plt.show()
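In newer versions, fetch_openml returns X as a pandas DataFrame by default, so X[36000] no longer selects a row. Select the row by position with .iloc and convert it to a NumPy array before reshaping (alternatively, pass as_frame=False to fetch_openml to get NumPy arrays, and the old indexing works unchanged):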
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
some_digit = X.iloc[36000].to_numpy()  # row 36000 by position, as a NumPy array that supports reshape
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary, interpolation="nearest")
plt.axis("off")
plt.show()
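Why the old X[36000] breaks: on a pandas DataFrame, square brackets look up column labels, not rows, so an integer that is not a column name raises a KeyError. The sketch below uses a small random DataFrame as a hypothetical stand-in for the real 70000 x 784 MNIST frame (random values, no download), with column names assumed to follow the pixel1 ... pixel784 naming of the OpenML dataset.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the MNIST frame: 5 random "images", 784 pixel columns each.
rng = np.random.default_rng(0)
columns = [f"pixel{i}" for i in range(1, 785)]
X = pd.DataFrame(rng.integers(0, 256, size=(5, 784)).astype(np.float64), columns=columns)

# X[3] is a *column* lookup on a DataFrame; there is no column labelled 3, so it fails.
try:
    X[3]
    old_style_failed = False
except KeyError:
    old_style_failed = True

# Select the row by position with .iloc, then convert to NumPy so reshape is available.
some_digit = X.iloc[3].to_numpy()
some_digit_image = some_digit.reshape(28, 28)

print(old_style_failed)        # True
print(some_digit_image.shape)  # (28, 28)
```

The same pattern applies to the real data: X.iloc[36000] picks the 36001st image regardless of how the index or columns are labelled.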
Note in particular that the MNIST labels are strings, so compare them against '5' rather than the integer 5:
y_train_5 = (y_train == '5' )
y_test_5 = (y_test == '5')
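The string labels matter because comparing them with the integer 5 raises no error; it simply matches nothing, so a classifier trained on such a target would see no positive examples. The sketch below uses a short hypothetical label Series in place of the real y_train and also shows the alternative of converting the labels to integers once, as the book does with astype(np.uint8).

```python
import numpy as np
import pandas as pd

# Hypothetical labels in string form, like the targets returned by fetch_openml.
y_train = pd.Series(["5", "0", "4", "1", "9", "5"])

# Comparing strings with the integer 5 silently matches nothing.
wrong = (y_train == 5)

# Comparing with the string '5' works as intended.
y_train_5 = (y_train == "5")

# Alternative: convert the labels to small integers once, then compare with 5.
y_train_int = y_train.astype(np.uint8)
y_train_5_int = (y_train_int == 5)

print(wrong.sum())                      # 0
print(y_train_5.sum())                  # 2
print(y_train_5.equals(y_train_5_int))  # True
```

Converting once is convenient if later code (for example a multiclass classifier) expects numeric labels; comparing against '5' is the smaller change if only the binary 5-detector is needed.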