% KNN -- k-nearest-neighbor classification (knnclassify)

function outClass = knnclassify(sample, TRAIN, group, K, distance,rule)
%KNNCLASSIFY classifies data using the nearest-neighbor method
%
%   CLASS = KNNCLASSIFY(SAMPLE,TRAINING,GROUP) classifies each row of the
%   data in SAMPLE into one of the groups in TRAINING using the nearest-
%   neighbor method. SAMPLE and TRAINING must be matrices with the same
%   number of columns. GROUP is a grouping variable for TRAINING. Its
%   unique values define groups, and each element defines the group to
%   which the corresponding row of TRAINING belongs. GROUP can be a
%   numeric vector, a string array, or a cell array of strings. TRAINING
%   and GROUP must have the same number of rows. CLASSIFY treats NaNs or
%   empty strings in GROUP as missing values and ignores the corresponding
%   rows of TRAINING. CLASS indicates which group each row of SAMPLE has
%   been assigned to, and is of the same type as GROUP.
%
%   CLASS = KNNCLASSIFY(SAMPLE,TRAINING,GROUP,K) allows you to specify K,
%   the number of nearest neighbors used in the classification. The default
%   is 1.
%
%   CLASS = KNNCLASSIFY(SAMPLE,TRAINING,GROUP,K,DISTANCE) allows you to
%   select the distance metric. Choices are
%             'euclidean'    Euclidean distance (default)
%             'cityblock'    Sum of absolute differences, or L1
%             'cosine'       One minus the cosine of the included angle
%                            between points (treated as vectors)
%             'correlation'  One minus the sample correlation between
%                            points (treated as sequences of values)
%             'Hamming'      Percentage of bits that differ (only
%                            suitable for binary data)
%
%   CLASS = KNNCLASSIFY(SAMPLE,TRAINING,GROUP,K,DISTANCE,RULE) allows you
%   to specify the rule used to decide how to classify the sample. Choices
%   are:
%             'nearest'   Majority rule with nearest point tie-break
%             'random'    Majority rule with random point tie-break
%             'consensus' Consensus rule
%
%   The default behavior is to use majority rule. That is, a sample point
%   is assigned to the class from which the majority of the K nearest
%   neighbors are from. Use 'consensus' to require a consensus, as opposed
%   to majority rule. When using the consensus option, points where not all
%   of the K nearest neighbors are from the same class are not assigned
%   to one of the classes. Instead the output CLASS for these points is NaN
%   for numerical groups or '' for string named groups. When classifying to
%   more than two groups or when using an even value for K, it might be
%   necessary to break a tie in the number of nearest neighbors. Options
%   are 'random', which selects a random tiebreaker, and 'nearest', which
%   uses the nearest neighbor among the tied groups to break the tie. The
%   default behavior is majority rule, nearest tie-break.
%
%   Examples:
%
%      % training data: two normal components
%      training = [mvnrnd([ 1  1],   eye(2), 100); ...
%                  mvnrnd([-1 -1], 2*eye(2), 100)];
%      group = [ones(100,1); 2*ones(100,1)];
%      gscatter(training(:,1),training(:,2),group);hold on;
%
%      % some random sample data
%      sample = unifrnd(-5, 5, 100, 2);
%      % classify the sample using the nearest neighbor classification
%      c = knnclassify(sample, training, group);
%
%      gscatter(sample(:,1),sample(:,2),c,'mc'); hold on;
%      c3 = knnclassify(sample, training, group, 3);
%      gscatter(sample(:,1),sample(:,2),c3,'mc','o');
%
%   See also CLASSIFY, CLASSPERF, CROSSVALIND, KNNIMPUTE, SVMCLASSIFY,
%   SVMTRAIN.

%   Copyright 2004-2008 The MathWorks, Inc.


%   References:
%     [1] Machine Learning, Tom Mitchell, McGraw Hill, 1997

% Require at least the three mandatory inputs (SAMPLE, TRAINING, GROUP).
bioinfochecknargin(nargin,3,mfilename)

% grp2idx sorts a numeric grouping var ascending, and a string grouping
% var by order of first occurrence. gindex maps each TRAIN row to a group
% number 1..ngroups; groups is the cellstr of group names in that order.
[gindex,groups] = grp2idx(group);
% Rows whose group is missing (NaN index from grp2idx) are removed from
% the training set, as documented in the help text above.
nans = find(isnan(gindex));
if ~isempty(nans)
    TRAIN(nans,:) = [];
    gindex(nans) = [];
end
ngroups = length(groups);

% TRAIN and GROUP must agree row-for-row; SAMPLE and TRAIN must have the
% same number of columns (features).
[n,d] = size(TRAIN);
if size(gindex,1) ~= n
    error(message('bioinfo:knnclassify:BadGroupLength'));
elseif size(sample,2) ~= d
    error(message('bioinfo:knnclassify:SampleTrainingSizeMismatch'));
end
m = size(sample,1);

% Validate K: defaults to 1, must be a numeric, non-NaN scalar >= 1.
if nargin < 4
    K = 1;
elseif ~isnumeric(K)
    error(message('bioinfo:knnclassify:KNotNumeric'));
end
if ~isscalar(K)
    error(message('bioinfo:knnclassify:KNotScalar'));
end

if K<1
    error(message('bioinfo:knnclassify:KLessThanOne'));
end

if isnan(K)
    error(message('bioinfo:knnclassify:KNaN'));
end

% Resolve the distance metric, allowing unambiguous prefix abbreviations
% (e.g. 'cos' for 'cosine'); 'c' alone is ambiguous and errors.
if nargin < 5 || isempty(distance)
    distance  = 'euclidean';
elseif ischar(distance)
    distNames = {'euclidean','cityblock','cosine','correlation','hamming'};
    i = find(strncmpi(distance, distNames,numel(distance)));
    if length(i) > 1
        error(message('bioinfo:knnclassify:AmbiguousDistance', distance));
    elseif isempty(i)
        error(message('bioinfo:knnclassify:UnknownDistance', distance));
    end
    distance = distNames{i};
else
    error(message('bioinfo:knnclassify:InvalidDistance'));
end

% Resolve the tie-break/consensus rule; default is 'nearest'.
if nargin < 6
    rule = 'nearest';
elseif ischar(rule)
    
    % lots of testers misspelled consensus.
    if strncmpi(rule,'conc',4)
        rule(4) = 's';
    end
    % NOTE(review): 'farthest' is accepted here but is not listed in the
    % help text above -- it is an undocumented rule.
    ruleNames = {'random','nearest','farthest','consensus'};
    i = find(strncmpi(rule, ruleNames,numel(rule)));
    % %   May need this if we add more rules and introduce the possibility of
    % %   ambiguity.
    %     if length(i) > 1
    %         error('bioinfo:knnclassify:AmbiguousRule', ...
    %             'Ambiguous ''Rule'' parameter value:  %s.', rule);
    %     else
    if isempty(i)
        error(message('bioinfo:knnclassify:UnknownRule', rule));
    end
    rule = ruleNames{i};
    %     end
else
    error(message('bioinfo:knnclassify:InvalidRule'));
end

% Calculate the distances from all points in the training set to all points
% in the test set.

% Hamming distance is only defined for binary (0/1) data.
if strncmpi(distance,'hamming',3)
        if ~all(ismember(sample(:),[0 1]))||~all(ismember(TRAIN(:),[0 1]))
            error(message('bioinfo:knnclassify:HammingNonBinary'));
        end
end
% dIndex(i,:) holds the row indices (into TRAIN) of the K nearest
% training points for sample row i, as returned by knnsearch.
dIndex = knnsearch(TRAIN,sample,'distance', distance,'K',K);
% find the K nearest

if K >1
    % classes(i,j): group index of the j-th nearest neighbor of sample i.
    classes = gindex(dIndex);
    % special case when we have one sample(test) point -- this gets turned into a
    % column vector, so we have to turn it back into a row vector.
    if size(classes,2) == 1
        classes = classes';
    end
    % count the occurrences of the classes
    
    counts = zeros(m,ngroups);
    for outer = 1:m
        for inner = 1:K
            counts(outer,classes(outer,inner)) = counts(outer,classes(outer,inner)) + 1;
        end
    end
    
    % L is the winning vote count per sample; outClass the winning group
    % index. max breaks exact ties by taking the lowest group index; tied
    % rows are revisited below under the chosen rule.
    [L,outClass] = max(counts,[],2);
    
    % Deal with consensus rule
    if strcmp(rule,'consensus')
        % Rows where the K neighbors are not unanimous get assigned to a
        % synthetic extra group (ngroups+1) that maps back to NaN for
        % numeric/logical groups or '' for string groups.
        noconsensus = (L~=K);
        
        if any(noconsensus)
            outClass(noconsensus) = ngroups+1;
            if isnumeric(group) || islogical(group)
                groups(end+1) = {'NaN'};
            else
                groups(end+1) = {''};
            end
        end
    else    % we need to check case where L <= K/2 for possible ties
        checkRows = find(L<=(K/2));
        
        for i = 1:numel(checkRows)
            % Which groups share the maximal vote count for this row?
            ties = counts(checkRows(i),:) == L(checkRows(i));
            numTies = sum(ties);
            if numTies > 1
                choice = find(ties);
                switch rule
                    case 'random'
                        % random tie break
                        
                        tb = randsample(numTies,1);
                        outClass(checkRows(i)) = choice(tb);
                    case 'nearest'
                        % use the closest neighbor belonging to one of the
                        % tied groups to break the tie (neighbors are
                        % scanned nearest-first)
                        for inner = 1:K
                            if ismember(classes(checkRows(i),inner),choice)
                                outClass(checkRows(i)) = classes(checkRows(i),inner);
                                break
                            end
                        end
                    case 'farthest'
                        % use the farthest neighbor belonging to one of
                        % the tied groups to break the tie (neighbors are
                        % scanned farthest-first)
                        for inner = K:-1:1
                            if ismember(classes(checkRows(i),inner),choice)
                                outClass(checkRows(i)) = classes(checkRows(i),inner);
                                break
                            end
                        end
                end
            end
        end
    end
    
else
    % K == 1: each sample simply takes its single nearest neighbor's group.
    outClass = gindex(dIndex);
end

% Convert back to original grouping variable so CLASS has the same type
% as GROUP (categorical, numeric/logical, char matrix, or cellstr).
if isa(group,'categorical')
    labels = getlabels(group);
    if isa(group,'nominal')
        groups = nominal(groups,[],labels);
    else
        groups = ordinal(groups,[],getlabels(group));
    end
    outClass = groups(outClass);
elseif isnumeric(group) || islogical(group)
    % 'NaN' sentinel strings (consensus failures) convert to numeric NaN.
    groups = str2num(char(groups)); %#ok
    outClass = groups(outClass);
elseif ischar(group)
    groups = char(groups);
    outClass = groups(outClass,:);
else %if iscellstr(group)
    outClass = groups(outClass);
end

% (Removed: web-page boilerplate accidentally pasted from the blog page
% this file was copied from; it was not part of the MATLAB source.)