MATLAB实现kNN分类
简介
Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
k近邻学习是一种常用的监督学习方法,其工作机制非常简单:给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练样本,然后基于这k个“邻居”的信息来进行预测1。
代码
- 加载数据
iris = xlsread('iris.csv');
使用xlsread可以读取表格中的数据,类型为mat.
- 打标签
iris(1:50,1) = 1;
iris(51:100,1) = 2;
iris(101:150,1) = 3;
Setosa标记为“1”,Versicolour标记为“2”,Virginica标记为“3”.
- 打乱数据顺序
randIndex = randperm(size(iris,1));
iris = iris(randIndex,:);
- 划分测试集和训练集
train = iris(1:100,:);
test = iris(101:150,:);
取100个数据为训练集,50个数据为测试集
- 抹除测试集标签
testLabelless = test(:,2:end);
- k近邻学习
k = 3;
predict = zeros(50,1);
for i=1:50
testTemp = ones(100,1) * testLabelless(i,:); %构造与训练数据相同形状的测试矩阵
diff = testTemp-train(:,2:end);%测试矩阵与训练矩阵作差
dist = sqrt(sum(diff.^2, 2));%获得测试点与每个训练点的欧氏距离
[sortDist, sortIndex] = sort(dist);%将距离排序
neighbors = train(sortIndex(1:k),:);%获取k个“邻居”
predict(i) = mode(neighbors(:,1));%“邻居”最多的类别为测试点类别
end
- 计算预测的准确率
err = bitxor(predict, test(:,1));
accuracy = 1-mean(err);
周志华. (2016). 机器学习. 清华大学出版社, 北京. ↩︎