机器学习之KNN(matlab)

例子:手写数字识别(使用 MNIST 数据集,网上有很多资源可以下载)

1)主程序:

clear all;
close all;

% Experiment settings. Declared global so that functions which declare the
% same globals can read them.
global ImageRow ImageCol TrainNum TestNum;
ImageRow = 28;  ImageCol = 28;   % MNIST image size in pixels
TrainNum = 2000;                 % number of training samples used
TestNum  = 900;                  % number of test samples classified
k = 8;                           % neighbours that vote per prediction

% Training set: load images and labels, keep the first TrainNum of each.
TrainData  = LoadMNISTImages('train-images.idx3-ubyte');
TrainLabel = LoadMNISTLabels('train-labels.idx1-ubyte');
TrainData  = TrainData(:, 1:TrainNum);
TrainLabel = TrainLabel(1:TrainNum);

% Test set: same procedure, keep the first TestNum samples.
TestData  = LoadMNISTImages('t10k-images.idx3-ubyte');
TestLabel = LoadMNISTLabels('t10k-labels.idx1-ubyte');
TestData  = TestData(:, 1:TestNum);
TestLabel = TestLabel(1:TestNum);

% Classify, then report the fraction of correct predictions in percent.
PredictLabel = knn(TrainData, TrainLabel, TestData, k);
accuracy = mean(PredictLabel == TestLabel);
disp(['准确率是:', num2str(accuracy*100), '%']);



2)knn.m

function PredictLabel=knn(dataX,LabelX,dataY,k)
%KNN Classify each column of dataY by a majority vote of its k nearest
%training columns in dataX (Euclidean distance).
%
%   PredictLabel = knn(dataX, LabelX, dataY, k)
%
%   dataX  - D x N matrix, one training sample per column.
%   LabelX - vector of training labels; LabelX(j) is the label of dataX(:,j).
%   dataY  - D x M matrix, one query sample per column.
%   k      - number of nearest neighbours that vote; clamped to N if larger.
%
%   PredictLabel - M x 1 column vector of predicted labels. Vote ties are
%   resolved by MODE, which returns the smallest of the tied labels.
%
%   Sample counts are taken from the inputs, so no global variables from
%   the calling script are required.

numTrain = size(dataX, 2);
numTest  = size(dataY, 2);
assert(size(dataY, 1) == size(dataX, 1), ...
    'knn: dataX and dataY must have the same number of rows (features).');

% Asking for more neighbours than there are training samples would index
% past the end of the sort order; use every training sample instead.
k = min(k, numTrain);

PredictLabel = zeros(numTest, 1);

for i = 1:numTest
    % Euclidean distance from query i to every training column (1 x N).
    differ = sqrt(sum(bsxfun(@minus, dataX, dataY(:, i)).^2, 1));

    % Ascending order: nearest training samples come first.
    [~, n] = sort(differ, 'ascend');

    % The most frequent label among the k nearest neighbours wins.
    PredictLabel(i) = mode(LabelX(n(1:k)));
end

end



3)LoadMNISTImages.m

function images = LoadMNISTImages(filename)
%LOADMNISTIMAGES Read an MNIST image file (magic number 2051) into a
%pixel-by-image matrix.
%
%   images = LoadMNISTImages(filename)
%
%   Returns a (numRows*numCols) x numImages double matrix with one image per
%   column and pixel values rescaled from 0..255 to [0, 1]. Each column is
%   the numRows x numCols image flattened in MATLAB column-major order.
%
%   Errors if the file cannot be opened, if the magic number is not 2051, or
%   if the pixel payload does not match the header dimensions. The file
%   handle is closed on every exit path, including these errors.

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file however this function exits (normal return or error).
closeFile = onCleanup(@() fclose(fp));

% Header fields are read as big-endian 32-bit integers.
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);

numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

% The rest of the file is one unsigned byte per pixel.
images = fread(fp, inf, 'unsigned char');
clear closeFile;   % all data is read; release the handle now

assert(numel(images) == numRows * numCols * numImages, ...
    ['Pixel count does not match header in ', filename]);

% Pixels are stored row by row, so read each image as numCols x numRows
% and transpose it to numRows x numCols.
images = reshape(images, numCols, numRows, numImages);
images = permute(images, [2 1 3]);

% Reshape to #pixels x #examples.
images = reshape(images, numRows * numCols, numImages);
% Convert to double and rescale to [0,1].
images = double(images) / 255;

end



4)LoadMNISTLabels.m

function labels = LoadMNISTLabels(filename)
%LOADMNISTLABELS Read an MNIST label file (magic number 2049) into a
%column vector.
%
%   labels = LoadMNISTLabels(filename)
%
%   Returns a numLabels x 1 double column vector holding the unsigned label
%   bytes stored after the header, in file order.
%
%   Errors if the file cannot be opened, if the magic number is not 2049, or
%   if the number of label bytes differs from the count in the header. The
%   file handle is closed on every exit path, including these errors.
%
%   The declared name matches the file name LoadMNISTLabels.m and the name
%   used by callers.

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file however this function exits (normal return or error).
closeFile = onCleanup(@() fclose(fp));

% Header fields are read as big-endian 32-bit integers.
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);

numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');

% The rest of the file is one unsigned byte per label.
labels = fread(fp, inf, 'unsigned char');
clear closeFile;   % all data is read; release the handle now

assert(size(labels,1) == numLabels, 'Mismatch in label count');

end



阅读更多
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

加入CSDN,享受更精准的内容推荐,与500万程序员共同成长!
关闭
关闭