import collections
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
def file2matrix(filename):
    """
    Load a tab-separated data file into a feature matrix and a label list.

    Each non-blank line must hold at least three numeric feature columns
    followed by an integer class label in the last column.  Blank lines
    (for example a trailing empty line at the end of the file) are skipped.

    :param filename: path of the data file
    :return: tuple ``(returnMat, classLabelVector)`` where ``returnMat`` is an
             ``(n, 3)`` float array holding the first three columns of every
             record and ``classLabelVector`` is the list of ``n`` int labels
    :raises ValueError: if a record has fewer than three feature columns or a
                        field cannot be parsed as a number
    """
    # Read the file exactly once; the context manager guarantees it is closed
    # even if parsing fails further down.
    with open(filename) as f:
        rows = [line.strip().split('\t') for line in f if line.strip()]

    returnMat = np.zeros((len(rows), 3))
    classLabelVector = []
    for index, fields in enumerate(rows):
        # Only the first three columns are features; the last one is the label.
        returnMat[index, :] = [float(value) for value in fields[:3]]
        classLabelVector.append(int(fields[-1]))
    return returnMat, classLabelVector
def autoNorm(dataSet):
    """
    Min-max normalise every column of ``dataSet``.

    Each column is shifted by its minimum and divided by its range
    (max - min).  A column whose values are all equal has a range of zero;
    such a column is mapped to all zeros instead of producing NaN/inf from a
    division by zero.

    :param dataSet: numeric array whose axis 0 indexes the samples
    :return: tuple ``(normalData, r, minVals)`` -- the normalised data, the
             per-column ranges (unchanged, so a constant column still reports
             a range of 0) and the per-column minima
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    r = maxVals - minVals
    # Divide by 1 where the range is 0: the numerator is 0 there as well, so
    # the normalised value becomes 0 rather than NaN.
    safeRange = np.where(r == 0, 1, r)
    normalData = (dataSet - minVals) / safeRange
    return normalData, r, minVals
def classify0(inX, dataSet, labels, k):
    """
    Classify ``inX`` by majority vote among its ``k`` nearest samples.

    :param inX: the sample to classify (one feature vector)
    :param dataSet: reference samples, one row per sample
    :param labels: labels of the reference samples, indexed like the rows of
                   ``dataSet``
    :param k: number of nearest neighbours that take part in the vote
    :return: the most common label among the ``k`` nearest neighbours
    """
    # Euclidean distance from inX to every reference row.
    offsets = inX - dataSet
    squared = np.sum(offsets ** 2, axis=1)
    distance = squared ** 0.5

    # Row indices of the k closest samples, nearest first.
    nearest = distance.argsort()[:k]
    votes = collections.Counter(labels[row] for row in nearest)
    return votes.most_common(1)[0][0]
def datingClassTest():
    """
    Hold-out evaluation of the nearest-neighbour classifier on the dating data.

    Loads ``./datingTestSet2.txt``, uses the first ``ratio`` fraction of the
    rows as test samples and the remaining rows as reference samples, and
    classifies every test sample with ``k = 1``.  Each prediction is printed
    next to the true label, followed by the overall error rate.

    NOTE: the features are used as loaded; no normalisation is applied here.
    """
    ratio = 0.1
    datingDataMat, datingLabels = file2matrix('./datingTestSet2.txt')
    m = datingDataMat.shape[0]
    numTest = int(m * ratio)
    err = 0
    for i in range(numTest):
        # Rows [0, numTest) are the test set; rows [numTest, m) are the references.
        res = classify0(datingDataMat[i, :], datingDataMat[numTest:, :],
                        datingLabels[numTest:], 1)
        print()
        print("predict: %d, true is : %d" % (res, datingLabels[i]))
        if res != datingLabels[i]:
            err += 1
    # With too few rows numTest is 0; report an error rate of 0 instead of
    # raising ZeroDivisionError.
    errRate = err / numTest if numTest else 0.0
    print("total err: %f" % errRate)
def createDataset():
    """
    Build the small hard-coded toy data set used for manual experiments.

    :return: tuple ``(group, labels)`` -- a 4x2 integer array of sample points
             and the list with the label of each point
    """
    # (point, label) pairs: the first two points carry label 1, the last two label 0.
    samples = (
        ((1, 101), 1),
        ((5, 89), 1),
        ((108, 5), 0),
        ((115, 8), 0),
    )
    group = np.array([list(point) for point, _ in samples])
    labels = [tag for _, tag in samples]
    return group, labels
def classify(test, group, labels, k):
    """
    Verbose k-nearest-neighbour classification that prints every step.

    Prints, in order: the indices of ``group`` sorted by distance to ``test``,
    the labels of the ``k`` nearest samples, the vote counter, and the
    winning label.

    :param test: the sample to classify (one feature vector)
    :param group: reference samples, one row per sample
    :param labels: labels of the reference samples, indexed like the rows of
                   ``group``
    :param k: number of nearest neighbours that take part in the vote
    :return: the most common label among the ``k`` nearest neighbours
    """
    dis = np.sum((test - group) ** 2, axis=1) ** 0.5
    # Sort once and reuse the ordering for both the printout and the vote.
    order = dis.argsort()
    print(order.tolist())
    k_labels = [labels[idx] for idx in order[:k]]
    print(k_labels)
    votes = collections.Counter(k_labels)
    print(votes)
    label = votes.most_common(1)[0][0]
    print(label)
    # Return the prediction as well so callers can use it programmatically.
    return label
if __name__ == '__main__':
    # Earlier manual experiments, kept for reference (disabled):
    #
    #   group, labels = createDataset()
    #   test = [101, 20]
    #   classify(test, group, labels, 3)
    #
    #   mat, label = file2matrix('./datingTestSet2.txt')
    #   print(mat)
    #   fig = plt.figure()
    #   ax = fig.add_subplot(111)
    #   ax.scatter(mat[:, 0], mat[:, 1], 15.0 * np.array(label), 10.0 * np.array(label))
    #   ax.legend()
    #   plt.show()

    # Run the hold-out evaluation on the dating data set.
    datingClassTest()
08-04
08-04
08-04
08-04
08-04
08-04
08-04
08-04