使用sklearn的KNN实现类,neighbors.KNeighborsClassifier,模型精度达到96.7%
数据集可以在线下载,也可以手动下载:
mnist数据集地址:https://www.lanzouw.com/iXDefxnl3fa
import torch, torchvision
from sklearn import neighbors
#加载mnist数据集
train_dataset = torchvision.datasets.MNIST(root='./data/',
train=True, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/',
train=False, download=False)
#获取mnist数据集, 并进行归一化,然后将(28*,28)的图片转成(1, 784)向量
train_data = (train_dataset.data/255).view(-1, 784)
train_label = train_dataset.targets
#加载测试集
test_data = (test_dataset.data/255).view(-1, 784)
test_label = test_dataset.targets
#训练模型
model = neighbors.KNeighborsClassifier(n_neighbors=8)
model.fit(train_data, train_label)
#模型预测
predict = model.predict(test_data)
#使用sklearn的score函数算精度,
acc = model.score(test_data, test_label)
print(acc)