matlab自带常见分类器的使用

 

目前了解到的 MATLAB 中分类器有: K 近邻分类器,随机森林分类器,朴素贝叶斯,集成学习方法,鉴别分析分类器,支持向量机。现将其主要函数使用方法总结如下,更多细节需参考 MATLAB 帮助文件。

  训练样本: train_data             % 矩阵,每行一个样本,每列一个特征
  训练样本标签: train_label       % 列向量
  测试样本: test_data
  测试样本标签: test_label
 
K 近邻分类器  ( KNN )
mdl = ClassificationKNN.fit(train_data,train_label,'NumNeighbors',1);
predict_label   =       predict(mdl, test_data);
accuracy         =       length(find(predict_label == test_label))/length(test_label)*100
               
 
随机森林分类器( Random Forest )
B = TreeBagger(nTree,train_data,train_label);
predict_label = predict(B,test_data);
 
 
朴素贝叶斯  ( Naive Bayes )
nb = NaiveBayes.fit(train_data, train_label);
predict_label   =       predict(nb, test_data);
accuracy         =       length(find(predict_label == test_label))/length(test_label)*100;
 
 
集成学习方法( Ensembles for Boosting, Bagging, or Random Subspace )
ens = fitensemble(train_data,train_label,'AdaBoostM1' ,100,'tree','type','classification');
predict_label   =       predict(ens, test_data);
 
 
鉴别分析分类器( discriminant analysis classifier )
obj = ClassificationDiscriminant.fit(train_data, train_label);
predict_label   =       predict(obj, test_data);
 
 
支持向量机( Support Vector Machine, SVM )
SVMStruct = svmtrain(train_data, train_label);

predict_label  = svmclassify(SVMStruct, test_data)

 

我自己代码如下:

 

 
  1. clc

  2. clear all

  3. load('wdtFeature');

  4.  
  5. %   训练样本:train_data % 矩阵,每行一个样本,每列一个特征

  6. %   训练样本标签:train_label % 列向量

  7. %   测试样本:test_data

  8. %   测试样本标签:test_label

  9. train_data = traindata'

  10. train_label = trainlabel'

  11. test_data = testdata'

  12. test_label = testlabel'

  13. % K近邻分类器 (KNN)

  14. % mdl = ClassificationKNN.fit(train_data,train_label,'NumNeighbors',1);

  15. % predict_label = predict(mdl, test_data);

  16. % accuracy = length(find(predict_label == test_label))/length(test_label)*100

  17. %

  18. % 94%

  19. % 随机森林分类器(Random Forest)

  20. % nTree = 5

  21. % B = TreeBagger(nTree,train_data,train_label);

  22. % predict_label = predict(B,test_data);

  23. %

  24. % m=0;

  25. % n=0;

  26. % for i=1:50

  27. % if predict_label{i,1}>0

  28. % m=m+1;

  29. % end

  30. % if predict_label{i+50,1}<0

  31. % n=n+1;

  32. % end

  33. % end

  34. %

  35. % s=m+n

  36. % r=s/100

  37.  
  38. % result 50%

  39.  
  40. % **********************************************************************

  41. % 朴素贝叶斯 (Na?ve Bayes)

  42. % nb = NaiveBayes.fit(train_data, train_label);

  43. % predict_label = predict(nb, test_data);

  44. % accuracy = length(find(predict_label == test_label))/length(test_label)*100;

  45. %

  46. %

  47. % % 结果 81%

  48. % % **********************************************************************

  49. % % 集成学习方法(Ensembles for Boosting, Bagging, or Random Subspace)

  50. % ens = fitensemble(train_data,train_label,'AdaBoostM1' ,100,'tree','type','classification');

  51. % predict_label = predict(ens, test_data);

  52. %

  53. % m=0;

  54. % n=0;

  55. % for i=1:50

  56. % if predict_label(i,1)>0

  57. % m=m+1;

  58. % end

  59. % if predict_label(i+50,1)<0

  60. % n=n+1;

  61. % end

  62. % end

  63. %

  64. % s=m+n

  65. % r=s/100

  66.  
  67. % 结果 97%

  68. % **********************************************************************

  69. % 鉴别分析分类器(discriminant analysis classifier)

  70. % obj = ClassificationDiscriminant.fit(train_data, train_label);

  71. % predict_label = predict(obj, test_data);

  72. %

  73. % m=0;

  74. % n=0;

  75. % for i=1:50

  76. % if predict_label(i,1)>0

  77. % m=m+1;

  78. % end

  79. % if predict_label(i+50,1)<0

  80. % n=n+1;

  81. % end

  82. % end

  83. %

  84. % s=m+n

  85. % r=s/100

  86. % result 86%

  87. % **********************************************************************

  88. % 支持向量机(Support Vector Machine, SVM)

  89. SVMStruct = svmtrain(train_data, train_label);

  90. predict_label = svmclassify(SVMStruct, test_data)

  91. m=0;

  92. n=0;

  93. for i=1:50

  94. if predict_label(i,1)>0

  95. m=m+1;

  96. end

  97. if predict_label(i+50,1)<0

  98. n=n+1;

  99. end

  100. end

  101.  
  102. s=m+n

  103. r=s/100

  104.  
  105. % result 86%

Fisher线性分类器是一种经典的线性分类器,它可以在二分类问题中有效地分类样本。其基本思想是将样本投影到一条直线上,使得同类样本的投影点尽可能地接近,异类样本的投影点尽可能地分开。下面是使用Matlab进行Fisher线性分类器实验的基本步骤: 1. 数据准备 首先,需要准备二分类样本数据,可以使用Matlab中的自带数据集,如digitDataset或fisheriris,也可以自己生成数据。这里以digitDataset为例,使用以下代码加载数据集: ```matlab digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); digitData = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true,'LabelSource','foldernames'); ``` 2. 特征提取 对于图像数据,需要进行特征提取,将图像转换为特征向量。在这里,可以使用常见的特征提取方法,如HOG、LBP等。这里以HOG为例,使用以下代码提取特征: ```matlab cellSize = [4 4]; hogFeatureSize = 36; digitData.ReadFcn = @(filename)readAndPreprocessImage(filename,cellSize,hogFeatureSize); trainingNumFiles = 750; rng(1) % For reproducibility [trainDigitData,testDigitData] = splitEachLabel(digitData,trainingNumFiles,'randomize'); ``` 其中,readAndPreprocessImage是自定义函数,用于读取图像并提取HOG特征。splitEachLabel函数用于将数据集分为训练集和测试集。 3. Fisher线性分类器训练 使用fitcdiscr函数进行Fisher线性分类器训练,代码如下: ```matlab % Extract HOG features from the training set trainingFeatures = zeros(size(trainDigitData.Files,1),hogFeatureSize); for i = 1:size(trainDigitData.Files,1) img = read(trainDigitData); trainingFeatures(i,:) = extractHOGFeatures(img,'CellSize',cellSize); end % Train a classifier faceClassifier = fitcdiscr(trainingFeatures,trainDigitData.Labels); ``` 其中,extractHOGFeatures函数用于提取HOG特征,fitcdiscr函数用于训练Fisher线性分类器。 4. Fisher线性分类器测试 使用predict函数进行Fisher线性分类器测试,代码如下: ```matlab % Extract HOG features from the test set testFeatures = zeros(size(testDigitData.Files,1),hogFeatureSize); for i = 1:size(testDigitData.Files,1) img = read(testDigitData); testFeatures(i,:) = extractHOGFeatures(img,'CellSize',cellSize); end % Test the classifier predictedLabels = predict(faceClassifier,testFeatures); ``` 其中,predictedLabels为预测的标签。 5. 分类器性能评估 使用confusionmat函数计算混淆矩阵,并计算分类器的准确率、精确率、召回率等性能指标,代码如下: ```matlab % Compute confusion matrix confMat = confusionmat(testDigitData.Labels,predictedLabels); % Compute accuracy accuracy = sum(diag(confMat))/sum(confMat(:)); % Compute precision precision = diag(confMat)./sum(confMat,2); % Compute recall recall = diag(confMat)./sum(confMat,1)'; ``` 其中,accuracy为分类器的准确率,precision为分类器的精确率,recall为分类器的召回率。 以上就是使用Matlab进行Fisher线性分类器实验的基本步骤。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

点云实验室lab

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值