# MATLAB svmtrain 和 svmclassify 函数使用示例



% Start from a clean session: empty workspace, no open figures, cleared console.
clear;
close all;
clc;

%% ================ load fisheriris.mat ================
% Load the iris data set; the sections below read the variables
% meas (feature matrix) and species (class labels) from it.
load('fisheriris.mat');

1、对于线性分类问题，我们选取线性核函数，原始数据包括训练数据和测试数据两部分。



% Linear-kernel example: prepare the data and train the classifier.
% NOTE(review): svmtrain/svmclassify belong to older Statistics Toolbox
% releases; newer MATLAB versions provide fitcsvm/predict instead.
data  = meas(51:end, 3:4);   % features: columns 3 and 4, rows 51 to end
group = species(51:end);     % class labels of the same rows

% Random permutation of the row indices, used to shuffle before splitting.
N   = size(data, 1);
idx = randperm(N);

% SVM train
T      = floor(N*0.9);       % first 90% of the shuffled rows are training data
xdata  = data(idx(1:T), :);
xgroup = group(idx(1:T));

% Default (linear) kernel; 'Showplot' draws the training points and the
% separating line in a figure.
svmStr = svmtrain(xdata, xgroup, 'Showplot', true);



% SVM predict: classify the held-out rows and report the accuracy in percent.
ydata  = data(idx(T+1:end), :);   % rows not used for training
ygroup = group(idx(T+1:end));     % their true labels

% Number of test samples, counted from the held-out set itself.
% floor(N*0.1) only matches the held-out count when N is a multiple of 10;
% for any other N the accuracy would be divided by the wrong total.
P = numel(ygroup);

% Classify; 'Showplot' adds the classified points to the svmtrain figure.
pgroup = svmclassify(svmStr, ydata, 'Showplot', true);

% Mark the test points with large blue squares on the same axes.
hold on;
plot(ydata(:,1), ydata(:,2), 'bs', 'Markersize', 12);
hold off;

% Share of predicted labels that equal the true labels, in percent.
accuracy1 = sum(strcmp(pgroup, ygroup))/P*100;

2、对于非线性分类问题，我们选取高斯核函数（RBF），原始数据包括训练数据和测试数据两部分。



% RBF-kernel example: prepare the data (features are columns 1 and 2 this time).
data  = meas(51:end, 1:2);   % features: columns 1 and 2, rows 51 to end
group = species(51:end);     % class labels of the same rows

% Random permutation of the row indices, used to shuffle before splitting.
N   = size(data, 1);
idx = randperm(N);

% SVM train
T      = floor(N*0.9);       % first 90% of the shuffled rows are training data
xdata  = data(idx(1:T), :);
xgroup = group(idx(1:T));



% different sigma
% Train one RBF-kernel SVM per kernel width, each in its own figure, so the
% resulting decision boundaries can be compared side by side.
% After the loop, sigma and svmStr hold the values of the last run (sigma = 3).
for sigma = [0.5 1 3]
    figure;
    svmStr = svmtrain(xdata, xgroup, 'kernel_function', 'rbf', ...
                      'rbf_sigma', sigma, 'showplot', true);
    title(sprintf('sigma = %g', sigma));
end



% different C
% Train one RBF-kernel SVM per box constraint C, each in its own figure, so
% the effect of the soft-margin penalty on the boundary can be compared.
% The title is generated from the value actually passed to svmtrain; the
% hard-coded titles used before ('C = 0.1', 'C = 1', 'C = 10') did not match
% the values being trained (1, 8, 64).
% After the loop, C and svmStr hold the values of the last run (C = 64).
for C = [1 8 64]
    figure;
    svmStr = svmtrain(xdata, xgroup, 'kernel_function', 'rbf', ...
                      'boxconstraint', C, 'showplot', true);
    title(sprintf('C = %g', C));
end



% SVM predict: classify the held-out rows with an RBF-kernel SVM and report
% the accuracy in percent.
ydata  = data(idx(T+1:end), :);   % rows not used for training
ygroup = group(idx(T+1:end));     % their true labels

% Number of test samples, counted from the held-out set itself.
% floor(N*0.1) only matches the held-out count when N is a multiple of 10;
% for any other N the accuracy would be divided by the wrong total.
P = numel(ygroup);

% sigma = 1, C = 1 (the svmtrain defaults): retrain in a fresh figure.
figure;
svmStr = svmtrain(xdata, xgroup, 'kernel_function', 'rbf', 'showplot', true);

% Classify; 'Showplot' adds the classified points to the svmtrain figure.
pgroup = svmclassify(svmStr, ydata, 'Showplot', true);

% Mark the test points with large blue squares on the same axes.
hold on;
plot(ydata(:,1), ydata(:,2), 'bs', 'Markersize', 12);
hold off;

% Share of predicted labels that equal the true labels, in percent.
accuracy2 = sum(strcmp(pgroup, ygroup))/P*100;