数据集给出了小麦的特征数据及其类别。每个样本由7个特征属性组成,可以看作7维空间中的一个点。我们通过计算两个样本之间的欧氏距离来度量样本间的相似度。分类时采用一条简单的规则:对于一个新样本,在数据集中找到与它距离最近的点,并把该样本归入与这个最近点相同的类别。模型的准确率采用10折交叉验证来评估。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __author__ : '小糖果'
import numpy as np
import matplotlib.pyplot as plt
class KnnRecomender(object):
    """k-nearest-neighbour classifier with a built-in cross-validation check.

    The samples are read from a whitespace-separated text file in which
    every line holds the numeric feature values followed by the class label
    in the last column.  Similarity between two samples is the Euclidean
    distance between their feature vectors.

    Attributes set by the constructor:
        features -- 2-D float array, one row per sample
        labels   -- 1-D array with the label of every row of ``features``
        k        -- number of nearest neighbours that vote on a label
        acc      -- mean accuracy of the last ``crossValidata`` run
    """

    def __init__(self, fname, k=1):
        """Load the data file ``fname`` and remember ``k``.

        :param fname: path of the whitespace-separated data file.
        :param k: number of nearest neighbours used for voting.
        :raises ValueError: if ``k`` is smaller than 1, or if a feature
            column cannot be converted to ``float``.
        """
        if k < 1:
            # k == 0 would otherwise only fail later, deep inside the vote.
            raise ValueError("k must be at least 1, got {!r}".format(k))
        data = []
        labels = []
        with open(fname) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. an empty line at the end of the
                    # file) carry no sample; indexing them would raise.
                    continue
                data.append([float(x) for x in fields[:-1]])
                labels.append(fields[-1])
        self.features = np.array(data)
        self.labels = np.array(labels)
        self.k = k
        self.acc = 0.0

    def plurality(self, results):
        """Return the label that occurs most often in ``results``.

        :param results: iterable of hashable labels.
        :returns: the most frequent label; on a tie the first label reached
            when iterating over the tally wins.
        :raises ValueError: if ``results`` is empty.
        """
        counts = {}
        for v in results:
            counts[v] = counts.get(v, 0) + 1
        if not counts:
            raise ValueError("plurality() needs at least one label")
        return max(counts, key=counts.get)

    def applyModel(self, testing_feats, model):
        """Predict a label for every row of ``testing_feats``.

        :param testing_feats: 2-D array, one row per sample to classify.
        :param model: ``(training_feats, labels)`` pair; ``training_feats``
            is a 2-D array and ``labels`` holds one label per row.
        :returns: 1-D array with one predicted label per test row.
        """
        training_feats, labels = model
        training_feats = np.asarray(training_feats)
        results = []
        for f in testing_feats:
            # One row-wise norm gives the distance to every training sample.
            dists = np.linalg.norm(training_feats - f, axis=1)
            # Sorting (distance, label) pairs breaks distance ties by label.
            nearest = sorted(zip(dists, labels))[:self.k]
            results.append(self.plurality([label for _, label in nearest]))
        return np.array(results)

    def accuracy(self, test_model, learn_model):
        """Return the fraction of test samples that are classified correctly.

        :param test_model: ``(features, labels)`` of the held-out samples.
        :param learn_model: ``(features, labels)`` used as the neighbours.
        """
        preds = self.applyModel(test_model[0], learn_model)
        return np.mean(preds == test_model[1])

    def crossValidata(self, folds=10):
        """Run interleaved cross-validation and store the result in ``acc``.

        Fold ``i`` tests on the samples ``i, i + folds, i + 2 * folds, ...``
        and uses all remaining samples as neighbours.  Folds that would have
        no test samples or no training samples are skipped, so a data set
        with fewer samples than folds does not turn the average into NaN.

        :param folds: number of folds (default 10, as in the original code).
        :raises ValueError: if ``folds`` is smaller than 2.
        """
        if folds < 2:
            raise ValueError(
                "folds must be at least 2, got {!r}".format(folds))
        n_samples = self.features.shape[0]
        total = 0.0
        evaluated = 0
        for fold in range(folds):
            training = np.ones(n_samples, bool)
            training[fold::folds] = False
            testing = ~training
            if not testing.any() or not training.any():
                continue
            # Boolean indexing already returns fresh arrays.
            learn_model = (self.features[training], self.labels[training])
            test_model = (self.features[testing], self.labels[testing])
            total += self.accuracy(test_model, learn_model)
            evaluated += 1
        self.acc = total / evaluated if evaluated else 0.0

    def standard(self):
        """Rescale every feature column to zero mean and unit variance.

        Columns whose standard deviation is zero are only centred; dividing
        them by zero would fill the column with NaN.
        """
        m = self.features.mean(axis=0)
        s = self.features.std(axis=0)
        s = np.where(s == 0, 1.0, s)
        self.features = (self.features - m) / s
def test(fpath=r'C:\Users\TD\Desktop\data\Machine Learning\1400OS_02_Codes\data\seeds.tsv'):
    """Print the cross-validated accuracy before and after standardisation.

    :param fpath: path of the data file to evaluate; defaults to the
        location used by the original script.
    """
    instance = KnnRecomender(fpath)
    instance.crossValidata()
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print("the accuracy is {:.2f}%".format(instance.acc * 100))
    # Standardise the features, then evaluate again.
    instance.standard()
    instance.crossValidata()
    print("the accuracy is {:.2f}%".format(instance.acc * 100))
# Run the demo only when the file is executed as a script, not on import.
if __name__ == '__main__':
    test()
运行结果:
the accuracy is 89.52% (没有标准化)
the accuracy is 94.29% (标准化后)