
本文介绍了K-Nearest Neighbor (KNN) 分类算法的MATLAB实现,包括cvKnn函数的详细说明,以及如何使用该函数进行分类和距离计算。还提供了一个简单的KNN应用示例,展示了原型向量和分类向量的二维可视化。


% cvEucdist - Euclidean distance

% Synopsis
%   [d] = cvEucdist(X, Y)
% Description
%   cvEucdist calculates a squared euclidean distance between X and Y.
% Inputs ([]s are optional)
%   (matrix) X        D x N matrix where D is the dimension of vectors
%                     and N is the number of vectors.
%   (matrix) [Y]      D x P matrix where D is the dimension of vectors
%                     and P is the number of vectors.
%                     If Y is not given, the squared L2 norm of each column
%                     of X is computed and a 1 x N matrix (not N x 1) is returned.
% Outputs ([]s are optional)
%   (matrix) d        N x P matrix where d(n,p) represents the squared
%                     euclidean distance between X(:,n) and Y(:,p).
% Examples
%   X = [1 2
%        1 2];
%   Y = [1 2 3
%        1 2 3];
%   d = cvEucdist(X, Y)
% %      0     2     8
% %      2     0     2
% See also
%   cvMahaldist

% Authors
%   Naotoshi Seo <sonots(at)sonots.com>
% License
%   The program is free to use for non-commercial academic purposes,
%   but for course works, you must understand what is going inside to use.
%   The program can be used, modified, or re-distributed for any purposes
%   if you or one of your group understand codes (the one must come to
%   court if court cases occur.) Please contact the authors if you are
%   interested in using the program without meeting the above conditions.
% Changes
%   06/2006  First Edition
function d = cvEucdist(X, Y)
 % Squared Euclidean distance between the columns of X (D x N) and the
 % columns of Y (D x P).
 %   - Y omitted or empty: returns a 1 x N row holding the sum of squares
 %     of each column of X (NaNs are NOT masked on this path).
 %   - Otherwise: returns an N x P matrix, d(n,p) being the squared
 %     distance between X(:,n) and Y(:,p). NaN entries are treated as
 %     missing: a dimension contributes to d(n,p) only when both X(:,n)
 %     and Y(:,p) are non-NaN in that dimension.
 if ~exist('Y', 'var') || isempty(Y)
     % Single-argument form: column-wise sum of squares, returned as a row.
     U = ones(size(X, 1), 1);
     d = abs(X'.^2*U).';
     return;
 end

 % Validity masks; NaNs are zeroed so they drop out of every product below.
 V = ~isnan(X); X(~V) = 0; % D x N, all ones when X has no NaN
 U = ~isnan(Y); Y(~U) = 0; % D x P, all ones when Y has no NaN

 % Expansion of |x - y|^2 = |x|^2 - 2*x'*y + |y|^2, with each squared term
 % masked by the other operand's validity mask.
 d1 = X'.^2*U;   % sum of x.^2 over dimensions where y is valid
 d3 = V'*Y.^2;   % sum of y.^2 over dimensions where x is valid
 d2 = X'*Y;      % cross term (zeroed NaNs contribute nothing)

 % abs() guards against tiny negative values caused by rounding.
 d = abs(d1-2*d2+d3);
% X = X';
% Y = Y';
% for i=1:size(X,1)
%     for j=1:size(Y,1)
%         d(i,j)=(norm(X(i,:)-Y(j,:)))^2;  %计算每个测试样本与所有训练样本的欧氏距离
%     end

% end

% cvKnn - K-Nearest Neighbor classification
% Synopsis
%   [Class] = cvKnn(X, Proto, ProtoClass, [K], [distFunc])
% Description
%   K-Nearest Neighbor classification
% Inputs ([]s are optional)
%   (matrix) X        D x N matrix representing column classifiee vectors
%                     where D is the number of dimensions and N is the
%                     number of vectors.
%   (matrix) Proto    D x P matrix representing column prototype vectors
%                     where D is the number of dimensions and P is the
%                     number of vectors.
%   (vector) ProtoClass
%                     1 x P vector containing class labels for prototype
%                     vectors. 
%   (scalar) [K = 1]  K-NN's K. Search K nearest neighbors.
%   (func)   [distFunc = @cvEucdist]
%                     A function handle for distance measure. The function
%                     must have two arguments for matrix X and Y. See
%                     cvEucdist.m (Euclidean distance) as a reference.
% Outputs ([]s are optional)
%   (vector) Class    1 x N vector containing classified class labels 
%                     for X. Class(n) is the class id for X(:,n). 
%   (matrix) [Rank]   Available only for NN (K = 1) now.
%                     nClass x N matrix containing ranking class labels
%                     for X. Rank(1,n) is the 1st candidate which is 
%                     the same with Class(n), Rank(2,n) is the 2nd 
%                     candidate, Rank(3,n) is the 3rd, and so on.
% See also
%   cvEucdist, cvMahaldist

% Authors
%   Naotoshi Seo <sonots(at)sonots.com>
% License
%   The program is free to use for non-commercial academic purposes,
%   but for course works, you must understand what is going inside to use.
%   The program can be used, modified, or re-distributed for any purposes
%   if you or one of your group understand codes (the one must come to
%   court if court cases occur.) Please contact the authors if you are
%   interested in using the program without meeting the above conditions.
% Changes
%   04/01/2005  First Edition
function [Class, Rank] = cvKnn(X, Proto, ProtoClass, K, distFunc)
% K-nearest-neighbour classification of the columns of X against the
% labelled prototype columns of Proto.
%   X          D x N classifiee vectors (one per column)
%   Proto      D x P prototype vectors (one per column)
%   ProtoClass class label of each prototype (P elements)
%   K          number of neighbours to vote (default 1)
%   distFunc   handle called as distFunc(X, Proto), expected to return an
%              N x P distance matrix (default @cvEucdist)
%   Class      label assigned to each column of X
%   Rank       (K == 1 only) nClass x N; column n lists the class labels in
%              order of their nearest prototype to X(:,n)
% Raises an error when X and Proto have a different number of rows.
if ~exist('K', 'var') || isempty(K)
    K = 1; % default: plain nearest neighbour
end
if ~exist('distFunc', 'var') || isempty(distFunc)
    distFunc = @cvEucdist;
end
if size(X, 1) ~= size(Proto, 1)
    error('Dimensions of classifiee vectors and prototype vectors do not match.');
end
N = size(X, 2);

% Distances between classifiees (rows) and prototypes (columns): N x P
d = distFunc(X, Proto);

if K == 1
    % Nearest neighbour only: a row-wise minimum avoids a full sort.
    [mini, IndexProto] = min(d, [], 2);
    Class = ProtoClass(IndexProto);
    if nargout == 2
        % Prototype indices per classifiee, nearest first (P x N). The
        % explicit dimension keeps this correct when P == 1.
        [sorted, order] = sort(d', 1);
        % Force P x N: indexing a vector with a column (N == 1) would
        % otherwise follow the orientation of ProtoClass.
        RankIndex = reshape(ProtoClass(order), size(order));
        % Collapse e.g. [2 1 2 3 1 5 4 1 2]' into [2 1 3 5 4]': keep each
        % label once, ordered by its first (closest) occurrence.
        for n = 1:N
            [ClassLabel, firstPos] = unique(RankIndex(:,n), 'first');
            [sorted, byRank] = sort(firstPos);
            Rank(:,n) = ClassLabel(byRank);
        end
    end
else
    % Prototype indices per classifiee, nearest first (P x N).
    [sorted, IndexProto] = sort(d', 1);
    clear d;
    % Keep the K closest prototypes of every classifiee.
    IndexProto = IndexProto(1:K,:);
    % Force K x N so the column-wise vote below also holds when N == 1.
    KnnClass = reshape(ProtoClass(IndexProto), size(IndexProto));
    % Count, for every class label, the votes it gets per classifiee.
    ClassLabel = unique(ProtoClass);
    nClass = length(ClassLabel);
    ClassCounter = zeros(nClass, N);
    for i = 1:nClass
        ClassCounter(i,:) = sum(KnnClass == ClassLabel(i), 1);
    end
    % max() returns the first maximum, so ties go to the smallest label.
    [maxi, winnerLabelIndex] = max(ClassCounter, [], 1);
    % Future Work: Handle ties somehow
    Class = ClassLabel(winnerLabelIndex);
end


function main
% Demo: classify five 2-D test points with cvKnn (default K and distance)
% against nine labelled prototypes, then plot prototypes as stars and
% classified points as circles, coloured by class.

% Prototype vectors, one per column (2 x 9).
trainData = [
    0.6213    0.5226    0.9797    0.9568    0.8801    0.8757    0.1730    0.2714    0.2523
    0.7373    0.8939    0.6614    0.0118    0.1991    0.0648    0.2987    0.2844    0.4692
    ];
% Class label of each prototype column.
trainClass = [
    1     1     1     2     2     2     3     3     3
    ];
% Vectors to classify, one per column (2 x 5).
testData = [
    0.9883    0.5828    0.4235    0.5155    0.3340
    0.4329    0.2259    0.5798    0.7604    0.5298
    ];

% main
testClass = cvKnn(testData, trainData, trainClass);

% plot prototype vectors
classLabel = unique(trainClass);
nClass     = length(classLabel);
plotLabel = {'r*', 'g*', 'b*'}; % one marker style per class (three classes)
for i=1:nClass
    A = trainData(:, trainClass == classLabel(i));
    plot(A(1,:), A(2,:), plotLabel{i});
    hold on;
end

% plot classifiee vectors
plotLabel = {'ro', 'go', 'bo'};
for i=1:nClass
    A = testData(:, testClass == classLabel(i));
    plot(A(1,:), A(2,:), plotLabel{i});
    hold on;
end
legend('1: prototype','2: prototype', '3: prototype', '1: classifiee', '2: classifiee', '3: classifiee', 'Location', 'NorthWest');
title('K nearest neighbor');
hold off;

