本文参考了机器学习实战3-sklearn使用下载MNIST数据集进行分类项目,针对我使用该教程中出现的问题提出解决方案。
1.数据集加载
原文中使用了如下代码进行数据集的加载:
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
mnist
在更高级版本的sklearn.datasets中已经没有fetch_mldata()方法,使用该方法进行数据集的加载则会报错。
此时应该使用如下方法进行数据集的加载:
from sklearn.datasets import fetch_openml
mnist = fetch_openml("mnist_784", data_home="./")
print(mnist)
data_home根据数据集存放的位置进行更改。
代码问题
在原文的2.1、随机梯度下降(SGD)分类器中有如下代码:
y_train_5 = (y_train == 5) # True for all 5s, False for all other digits.
y_test_5 = (y_test == 5)
在y_train和y_test中保存的数字全都是以字符串的形式,使用上述代码的结果是y_train_5,y_test_5中的所有数据都是false,因此修改代码为:
y_train_5 = (y_train == '5')
y_test_5 = (y_test =='5')
希望能够有所帮助。