代码:
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 23 11:48:16 2018
@author: 安颖
"""
import numpy as np
import matplotlib.pyplot as plt
#计算两样本间距离
def getinstance(x,y):
#欧氏距离求平方和
for i in range(4):
inst = pow(x[i]-y[i],2)
return np.sqrt(inst)
#knn 算法
def knn(x,t_train,k):
#距离矩阵
disarr = []
for i in range(len(t_train)):
disarr.append(getinstance(x,t_train[i]))
#argsort函数返回的是数组值从小到大的索引值
index = list(np.argsort(disarr))
#记录k近邻点属于不同类的数目
indexlable = [0]*3
for i in range(k):
#y为第i近的点
y = t_train[index[i]]
if y[4] == 1:
indexlable[0] += 1
elif y[4] == 2:
indexlable[1] += 1
else:
indexlable[2] += 1
lable = np.argsort(indexlable)
#返回同一类中最多的类
return (lable[2]+1)
# k_fold 产生一个迭代器
#n为数据集,n_folds为折数
def k_fold(data,n_folds):
#根据数据集格式,1/2/3类分别集中在三段
label_a = data[:50]
label_b = data[50:100]
label_c = data[100:]
#构建一个折数为n_folds的数组,每组中的数为均分的数目,用于k折交叉验证循环产生训练集和测试集
fold_sizes = (50 / n_folds) * np.ones(n_folds, dtype=np.int)
#开始位置为0
current = 0
#循环产生训练集和测试集
for fold_size in fold_sizes:
start, stop = current, current + int(fold_size)
#拼接训练集
train_index = list(np.concatenate((label_a[:start], label_a[stop:])))
index1 = list(np.concatenate((label_b[:start], label_b[stop:])))
index2 = list(np.concatenate((label_c[:start], label_c[stop:])))
train_index.extend(index1)
train_index.extend(index2)
#拼接测试集
test_index = list(label_a[start:stop])
test_index.extend(label_b[start:stop])
test_index.extend(label_c[start:stop])
#yield 传回函数 用.next()或在循环中调用
yield train_index, test_index
current = stop # move one step forward
#交叉验证
#循环计算错误率求均值
def cross_validate(X,kn):
error = 0.0
n_folds=5
kf = k_fold(X, n_folds)
for train_index, test_index in kf:
error += error_rate(test_index,train_index,kn)
return round(error/n_folds , 2)
#计算错误率
def error_rate(test,train,k):
error = 0
for i in range(len(test)):
#预测类别
pre_lable = knn(test[i],train,k)
if pre_lable != int(test[i][4]):
error += 1
return float(error/len(test))
if __name__ == '__main__':
t_data = []
#数据预处理
with open('iris.txt', 'r') as data_txt:
data = data_txt.readlines()
for line in data:
temp = line.split(',')
t_data.append([float(temp[0]),float(temp[1]),float(temp[2]),float(temp[3]),int(temp[4])])
t_data = np.array(t_data)
#最优错误率
minerror = 1.0
#y存放错误率数组
y = []
for j in range(120):
newerror = cross_validate(t_data,j+1)
if minerror > newerror:
minerror = newerror
mink = j+1
y.append(newerror)
x = range(1,121)
#画折线图
plt.plot(x,y,linewidth=3,color='r')
print("最小错误率为:"+str(minerror)+"此时k值为:"+str(mink))
知识点:
1、画折线图
import matplotlib.pyplot as plt
y = []
x = range(1,121)
plt.plot(x,y,linewidth=3,color='r')
2、append()方法是指在列表末尾增加一个数据项。格式为list.append(***)
extend()方法是指在列表末尾增加一个数据集合。格式为list.extend(anotherlist)
insert()方法是指在某个特定位置前面增加一个数据项。格式为list.insert(index,***)
3、yield关键字(还不太懂)
yield 的作用就是把一个函数变成一个 generator,带有 yield 的函数不再是一个普通函数,Python 解释器会将其视为一个 generator,调用 fab(5) 不会执行 fab 函数,而是返回一个 iterable 对象!在 for 循环执行时,每次循环都会执行 fab 函数内部的代码,执行到 yield b 时,fab 函数就返回一个迭代值,下次迭代时,代码从 yield b 的下一条语句继续执行,而函数的本地变量看起来和上次中断执行前是完全一样的,于是函数继续执行,直到再次遇到 yield。