# NOTE: code not yet tested; the dataset must be downloaded first. (代码待测试,数据需下载)
import operator

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.datasets as dsets
from torch.utils.data import DataLoader
class Knn:
def __init__(self):
    """Create an unfitted classifier.

    No state is prepared here; the training data is attached later by
    ``fit`` (as ``self.Xtr`` / ``self.Ytr``).
    """
def fit(self, X_train, Y_train):
    """Memorise the training set; this k-NN does no other work at fit time.

    Parameters
    ----------
    X_train :
        Training samples. ``predict`` reads ``self.Xtr.shape[0]`` and does
        element-wise arithmetic with it, so presumably a 2-D numpy array
        of shape (n_samples, n_features) -- TODO confirm with callers.
    Y_train :
        Presumably the labels matching the rows of ``X_train``; stored as-is.
    """
    # The inputs are kept by reference -- no copy or conversion is made.
    self.Xtr, self.Ytr = X_train, Y_train
def predict(self,k,dis,X_test):
assert dis == "E" or dis=="M", 'dis must Eor M'
num_test=X_test.shape[0]
labellist=[]
#使用欧式距离公式作为距离度量
if (dis=='E'):
for i in range(num_test):
distances = np.sqrt(np.sum(((self.Xtr-np.tile(X_test[i],(self.Xtr.shape[0],1)))**2),axis=1))
nearest_k=np.argsort(distances)
topk = nearest_k[: