load digit.mat X T; [d,m,c]=size(X); X=reshape(X,[d m*c]);
Y=reshape(repmat([1:c],[m 1]),[1 m*c]);
ks=[1:10]; t=5; v=mod(randperm(m*c),t)+1;
for i=1:t
Yh=knn(X(:,v~=i),Y(v~=i),X(:,v==i),ks);
s(i,:)=mean(Yh~=repmat(Y(v==i),[length(ks) 1]),2);
end
[dum,a]=min(mean(s)); k=ks(a); [d,r,c]=size(T);
T=reshape(T,[d r*c]); U=reshape(knn(X,Y,T,k),[r c]);
for i=1:c, C(:,i)=sum(U==i); end, C, sum(diag(C))/sum(sum(C))
function U=knn(X,Y,T,ks)
m=size(T,2); D2=repmat(sum(T.^2,1),[size(X,2) 1]);
D2=D2+repmat(sum(X.^2,1)',[1 m])-2*X'*T; [dum,z]=sort(D2,1);
for i=1:length(ks)
k=ks(i);
for j=1:m
Z=sort(Y(z(1:k,j))); g=find([1 Z(1:end-1)~=Z(2:end)]);
[dum,a]= max([g(2:end) k+1]-g); U(i,j)=Z(g(a));
end, end
结果C:
200 0 0 0 0 0 0 0 0 0
0 193 1 0 0 0 1 4 1 0
0 0 195 0 3 0 0 1 1 0
0 0 0 191 1 2 0 0 6 0
0 3 4 0 187 0 1 1 2 2
0 2 0 0 2 195 0 0 0 1
0 0 1 2 0 0 192 2 3 0
0 1 4 1 3 0 0 186 2 3
0 0 0 3 0 0 1 1 195 0
0 1 1 0 0 0 0 0 0 198
以下是每行代码的含义解释:
1. `load digit.mat X T;`: 加载名为 `digit.mat` 的文件,并将其中的变量 `X` 和 `T` 读入当前工作空间。`X` 和 `T` 分别表示输入数据和目标标签。
2. `[d,m,c]=size(X);`: 获取矩阵 `X` 的维度,并将其分别赋值给变量 `d`、`m` 和 `c`。`d` 表示特征的维度,`m` 表示每个样本的数量,`c` 表示类别的数量。
3. `X=reshape(X,[d m*c]);`: 将矩阵 `X` 重塑为 `d x (m*c)` 的大小,以便将每个样本展平为一列。
4. `Y=reshape(repmat([1:c],[m 1]),[1 m*c]);`: 创建标签向量 `Y`,其中包含重复的类别标签,使其与重塑后的 `X` 矩阵的列数相匹配。
5. `ks=[1:10];`: 创建一个包含整数 1 到 10 的向量,表示用于 k 近邻算法的 k 值。
6. `t=5;`: 设定变量 `t` 的值为 5,表示将数据分为 5 个子集。
7. `v=mod(randperm(m*c),t)+1;`: 生成一个长度为 `m*c` 的随机排列向量,并对其进行模运算,将其值限定在 1 到 `t` 之间,作为样本的子集索引。
8. `for i=1:t`: 开始一个循环,用于依次处理每个子集。
9. `Yh=knn(X(:,v~=i),Y(v~=i),X(:,v==i),ks);`: 调用 `knn` 函数,利用 k 近邻算法对测试数据进行分类,并将预测的类别标签存储在 `Yh` 中。
10. `s(i,:)=mean(Yh~=repmat(Y(v==i),[length(ks) 1]),2);`: 计算预测结果 `Yh` 与实际标签 `Y` 的不匹配率,并将其存储在矩阵 `s` 的第 `i` 行。
11. `[dum,a]=min(mean(s)); k=ks(a);`: 找到具有最小平均不匹配率的行索引,并从 `ks` 中获取对应的 k 值,赋值给变量 `k`。
12. `[d,r,c]=size(T);`: 获取矩阵 `T` 的维度,并将其分别赋值给变量 `d`、`r` 和 `c`。
13. `T=reshape(T,[d r*c]);`: 将矩阵 `T` 重塑为 `d x (r*c)` 的大小,以便将每个测试样本展平为一列。
14. `U=reshape(knn(X,Y,T
,k),[r c]);`: 调用 `knn` 函数,利用 k 近邻算法对测试数据 `T` 进行分类,并将分类结果 `U` 重塑为 `r x c` 的大小。
15. `for i=1:c, C(:,i)=sum(U==i); end, C, sum(diag(C))/sum(sum(C))`: 对每个类别计算在测试集中的分类结果,并计算准确率。
function U=knn(X,Y,T,ks)的理解:
这段代码定义了一个名为 `knn` 的函数,实现了 k 近邻算法的分类过程。以下是对该函数的详细解释:
- `function U=knn(X,Y,T,ks)`:定义了一个函数 `knn`,该函数接受四个输入参数 `X`、`Y`、`T` 和 `ks`,并返回一个输出变量 `U`。
- `m=size(T,2);`:获取测试数据 `T` 的列数,即测试样本的数量。
- `D2=repmat(sum(T.^2,1),[size(X,2) 1]);`:计算测试样本与训练样本之间的欧氏距离的平方,并重复该结果,形成一个与训练样本数量相同的矩阵。
- `D2=D2+repmat(sum(X.^2,1)',[1 m])-2*X'*T;`:计算测试样本与训练样本之间的欧氏距离的平方,并将其加到先前计算的平方距离矩阵 `D2` 上。
- `[dum,z]=sort(D2,1);`:对距离矩阵 `D2` 的每一列进行排序,并返回排序后的距离矩阵 `D2` 和排序的索引矩阵 `z`。
- `for i=1:length(ks)`:对 `ks` 中的每个 k 值进行循环。
- `k=ks(i);`:获取当前循环的 k 值。
- `for j=1:m`:对测试样本中的每个样本进行循环。
- `Z=sort(Y(z(1:k,j)));`:根据排序索引矩阵 `z`,获取距离最近的 k 个训练样本的类别标签,并对它们进行排序。
- `g=find([1 Z(1:end-1)~=Z(2:end)]);`:找到排序后的类别标签中不同类别之间的分界点。
- `[dum,a]= max([g(2:end) k+1]-g);`:找到最大的分界点,该分界点之前的标签即为最终预测的类别。
- `U(i,j)=Z(g(a));`:将预测的类别赋值给输出变量`U`。
整体而言,该 `knn` 函数通过计算测试样本与训练样本之间的欧氏距离,根据最近的 k 个训练样本的类别标签进行预测,最终输出预测结果矩阵 `U`。