Machine Learning 1: KNN

Learning the KNN algorithm
MATLAB code 1

%% KNN
clear all
clc
%% data
trainData = [1.0,2.0;1.2,0.1;0.1,1.4;0.3,3.5];
trainClass = [1,1,2,2];
testData = [0.5,2.3];
k = 3;

%% distance
row = size(trainData,1);
col = size(trainData,2);
test = repmat(testData,row,1);
dis = zeros(1,row);
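for i = 1:row
    sqSum = 0;   % running sum of squared coordinate differences (not named diff, to keep the built-in usable)
    for j = 1:col
        sqSum = sqSum + (test(i,j) - trainData(i,j)).^2;
    end
    dis(1,i) = sqrt(sqSum);   % Euclidean distance to training sample i
end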

%% sort
jointDis = [dis;trainClass];
sortDis= sortrows(jointDis');
sortDisClass = sortDis';

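%% find the majority label among the k nearest neighbors
nearestClass = sortDisClass(2,1:k);   % labels of the k nearest training samples
member = unique(nearestClass);        % distinct labels among them
num = numel(member);                  % number of distinct labels

maxCount = 0;                         % not named max, to keep the built-in max() usable
for i = 1:num
    count = sum(nearestClass == member(i));   % number of votes for label member(i)
    if count > maxCount
        maxCount = count;
        label = member(i);
    end
end

disp('The final classification result is:');
fprintf('%d\n',label)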

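Running the script prints "The final classification result is:" followed by 2, as expected: the three nearest training samples are samples 1, 3 and 4 (distances of about 0.58, 0.98 and 1.22), whose labels are 1, 2 and 2, so the majority vote gives class 2.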
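For comparison, the same single-query computation can be written without explicit loops. This is a minimal sketch, not part of the original post; it assumes MATLAB R2016b or later (for implicit expansion in `trainData - testData`) and swaps the hand-written vote for the built-in `mode`, which returns the smallest label when several labels tie.

```matlab
% Loop-free version of code 1 (sketch; needs MATLAB R2016b+ for implicit expansion)
trainData  = [1.0,2.0;1.2,0.1;0.1,1.4;0.3,3.5];  % one training sample per row
trainClass = [1,1,2,2];                          % label of each row
testData   = [0.5,2.3];
k = 3;

% Euclidean distance from testData to every row of trainData (column vector)
dis = sqrt(sum((trainData - testData).^2, 2));

% Indices of the training samples ordered from nearest to farthest
[~, order] = sort(dis);

% Labels of the k nearest samples; mode picks the most frequent one
% (the smallest label wins a tie)
nearestClass = trainClass(order(1:k));
label = mode(nearestClass);

fprintf('Predicted class: %d\n', label);
```

With the data from code 1, the nearest three samples are 1, 3 and 4 with labels 1, 2 and 2, so the script prints `Predicted class: 2`.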

MATLAB code 2

function y = knn(X, X_train, y_train, K)
%KNN k-Nearest Neighbors Algorithm.
%
%   INPUT:  X:         testing sample features, P-by-N_test matrix.
%           X_train:   training sample features, P-by-N matrix.
%           y_train:   training sample labels, 1-by-N row vector of
%                      nonnegative integers (recog uses them as indices).
%           K:         the k in k-Nearest Neighbors
%
%   OUTPUT: y    : predicted labels, 1-by-N_test row vector.
%
% Author: Ren Kan
[~,N_test] = size(X);

predicted_label = zeros(1,N_test);
for i=1:N_test
    % Find the K nearest training samples of the i-th test column.
    [~, neighbors] = top_K_neighbors(X_train,y_train,X(:,i),K);
    % Assign the most frequent label among those neighbors.
    predicted_label(i) = recog(y_train(neighbors),max(y_train));
end

y = predicted_label;

end
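Written as a formula (the notation here is ours, not the author's), `knn` together with its helpers `top_K_neighbors` and `recog` predicts, for a test column $\mathbf{x}$:

```latex
\hat{y}(\mathbf{x}) = \arg\max_{c \in \{0, 1, \dots, C\}}
  \sum_{i \in \mathcal{N}_K(\mathbf{x})} \mathbb{1}\left[ y_i = c \right]
```

Here $\mathcal{N}_K(\mathbf{x})$ is the set of indices of the $K$ training columns $\mathbf{x}_i$ with the smallest $\lVert \mathbf{x}_i - \mathbf{x} \rVert_2^2$, and $C$ is `max(y_train)`. Ties go to the smallest label, because MATLAB's `max` returns the index of the first maximum.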

Code for finding the K nearest neighbors:


function [dists,neighbors] = top_K_neighbors( X_train,y_train,X_test,K )
% Author: Ren Kan
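%   Input:
%   X_test: the test vector, P-by-1
%   X_train, y_train: the training set (y_train is not used here)
%   K: the number of neighbors
%   Output: dists are squared Euclidean distances, neighbors the matching column indices
[~, N_train] = size(X_train);
test_mat = repmat(X_test,1,N_train);
dist_mat = (X_train-double(test_mat)) .^2;
% Squared Euclidean distance: skipping sqrt does not change the ordering.
% Sum along dimension 1 explicitly so that P = 1 (a single feature) also works.
dist_array = sum(dist_mat,1);
[dists, neighbors] = sort(dist_array);
% The neighbors are the indices of the top K nearest points.
K = min(K, N_train);   % guard against K larger than the training set
dists = dists(1:K);
neighbors = neighbors(1:K);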

end
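`top_K_neighbors` handles one test vector per call. As a hedged sketch that is not from the original post, the squared distances between every training column and every test column can be computed in one matrix expression using $\lVert a - b \rVert^2 = \lVert a \rVert^2 - 2a^\top b + \lVert b \rVert^2$. The second test column below is a made-up example, and the code assumes MATLAB R2016b or later for implicit expansion.

```matlab
% Sketch: K nearest neighbors for several test columns at once (P-by-N layout)
X_train = [1.0 1.2 0.1 0.3;     % feature 1 of the four training samples
           2.0 0.1 1.4 3.5];    % feature 2
X_test  = [0.5 1.1;             % column 1: the test point from code 1
           2.3 0.2];            % column 2: a hypothetical extra test point
K = 3;

% N_train-by-N_test matrix of squared Euclidean distances
sqTrain = sum(X_train.^2, 1)';              % N_train-by-1 squared norms
sqTest  = sum(X_test.^2, 1);                % 1-by-N_test squared norms
D2 = sqTrain - 2 * (X_train' * X_test) + sqTest;
D2 = max(D2, 0);                            % clip tiny negatives caused by round-off

% Sort each column; the first K rows hold the nearest training indices
[sortedD2, idx] = sort(D2, 1);
neighbors = idx(1:K, :);                    % K-by-N_test indices
dists = sqrt(sortedD2(1:K, :));             % Euclidean distances of those neighbors
disp(neighbors)
```

The matrix form replaces the `repmat` call per test vector with a single matrix product, which is usually much faster when there are many test samples.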

Code for predicting the test label from the label frequencies among the K neighbors:


function result = recog( K_labels,class_num )
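%RECOG Majority vote over the labels of the K nearest neighbors.
%   K_labels:  1-by-K labels, nonnegative integers
%   class_num: the largest label value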
%   Author: Ren Kan
[~,K] = size(K_labels);
class_count = zeros(1,class_num+1);
for i=1:K
    class_index = K_labels(i)+1; % +1 is to avoid the 0 index reference.
    class_count(class_index) = class_count(class_index) + 1;
end
[~,result] = max(class_count);
result = result - 1; % Do not forget -1 !!!

end
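`recog` counts votes with an explicit loop. A compact alternative, sketched here under the same assumption that labels are nonnegative integers from 0 to `class_num`, uses `accumarray`; dividing the counts by K gives the fraction of neighbors carrying each label, which can be read as an estimated class probability. The input values are the neighbor labels from the code 1 example.

```matlab
% Sketch: the vote performed by recog, written with accumarray
K_labels  = [1 2 2];      % labels of the K nearest neighbors (from code 1)
class_num = 2;            % largest possible label

% counts(c+1) = number of neighbors whose label is c (+1 because MATLAB indexes from 1)
counts = accumarray(K_labels(:) + 1, 1, [class_num + 1, 1])';

% Share of the K votes received by each label
probs = counts / numel(K_labels);

% max returns the first maximum, so ties go to the smallest label, as in recog
[~, idx] = max(counts);
result = idx - 1;

fprintf('counts = %s, result = %d\n', mat2str(counts), result);
```

This prints `counts = [0 1 2], result = 2`, and `probs` is `[0 1/3 2/3]`.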

Applications include:
handwritten character recognition, digit CAPTCHA recognition, and so on.
