机器学习实验-Experiment 4:Naive Bayes朴素贝叶斯

综述

主要是简略记录一下经过,当时做实验的时候,实验报告其实写的要详细很多:但是凡事要自己亲自实验才有发言权,所以这些部分写的都较为简略。

朴素贝叶斯在机器学习课程中还是很重要的。思想也很美妙。
看似自然、简单、顺畅,但是一定要细心琢磨,理解
本文还简单介绍了我在实验中使用的混淆矩阵。一种更加直观的评价感受方案。回忆我在实验3中介绍的ROC曲线和AUC评价,希望可以打开新的思路。

使用朴素贝叶斯模型解决了实验问题并回答了相关问题:
1.推到并构建了多分类的朴素贝叶斯机器学习模型;
2.对比了该模型在不同的采样规模下的表现效果;
3.对模型进行了拉普拉斯光顺以避免特征未出现而引起的误差;
4.使用混淆矩阵(Confusion Matrix)对模型的表现进行了度量和评价。
老师在设计实验时,让我们分析了样本规模对于训练结果的影响,为了便于采样我给出了非重复采样的函数:
非重复采样函数:

function s = sampling(R, n)
% 选择抽样,R为记录集合,n为抽取的样本数
% 编写函数时用的测试数据
N = length(R);
t = 0;   % 处理过的记录总数
m = 0;   % 已选得的记录数
while 1
    U  = rand;
    if (N-t)*U < n-m
        m = m + 1;
        s(m) = R(t+1);
        % 若已抽取到足够的记录,则算法终止
        if m >= n, break, end
    end
    t = t + 1;
end

对于MATLAB-非重复采样输出采样和带标号采样可以看我之前的这篇文章
MATLAB-非重复采样输出采样和带标号采样

部分效果图:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

实验目的

1.深化学习并进一步掌握的朴素贝叶斯的原理以及实现方式;
2.体会贝叶斯分类器的设计思想和基本假设;
3.通过实践提升编程水平,提升通过动手解决实际问题的能力;
4.在实践中总结经验,回扣所学知识,巩固基础;
5.学习构建多分类贝叶斯模型,并对模型的分类效果进行评价和分析;

理论支持简述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码

朴素贝叶斯模型

close all
clc
clear
cont = [100,500,1000,2500,5000,7500,8000,10000];
 
laplace  = [3,5,4,4,3,2,3,3]
 
target   = 10000; %展示指定的数据采样结果
%本次实验
%需要说明的几点:
%1.我们采用了拉普拉斯光顺所以结果会有小的波动
%2.控制采样数目不是采用随机数而是采样
%存储不同的采样情
for ooo = 1:8
 
nums = cont(ooo); %控制训练数据的个数
 
 
a = 1;
b = 10000;  %设置随机数的取值范围
train = load('training_data.txt');
test = load('test_data.txt');
%这里我自写了一个采样的函数(不重叠采样)
%使用随机数是不安全的
if nums == 10000
else
    sams = sampling(1:10000,nums);
    train = train(sams,:);
end
%每次随机抽取100个例子
[m1,n] = size(train);
%m1是训练集的数目
[m2,n] = size(test);
%m2是测试集的数目
X = train(:,1:8);
Y = train(:,9);
X2 = test(:,1:8);
Y2 = test(:,9);
X2 = X2+1;
counter1 = 0;
counter2 = 0;
counter3 = 0;
counter4 = 0;
counter5 = 0;
%共计5类样本,记录不同类别的数量
count_1 = zeros(8,5);
count_2 = zeros(8,5);
count_3 = zeros(8,5);
count_4 = zeros(8,5);
count_5 = zeros(8,5);
%统计样本中属性的情况
%第1类中第i个属性取值为j的:数目
for i = 1:m1
    if(Y(i)==0)
        counter1 = counter1+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_1(j,k) = count_1(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==1)
        counter2 = counter2+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_2(j,k) = count_2(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==2)
        counter3 = counter3+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_3(j,k) = count_3(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==3)
        counter4 = counter4+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_4(j,k) = count_4(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==4)
        counter5 = counter5+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_5(j,k) = count_5(j,k) +1;
                    break;
                end
            end  
        end
    end     
end
y1 = (counter1+1)/(m1+5);
y2 = (counter2+1)/(m1+5);
y3 = (counter3+1)/(m1+5);
y4 = (counter4+1)/(m1+5);
y5 = (counter5+1)/(m1+5);
 
for i  = 1:8
    for j  = 1:5
        y_1(i,j) = (1+count_1(i,j))/(counter1+laplace(i));
        y_2(i,j) = (count_2(i,j)+1)/(counter2+laplace(i));
        y_3(i,j) = (count_3(i,j)+1)/(counter3+laplace(i));
        y_4(i,j) = (count_4(i,j)+1)/(counter4+laplace(i));
        y_5(i,j) = (count_5(i,j)+1)/(counter5+laplace(i));
    end
end
right_counter  = 0;
rescon = zeros(5,1);
realcon = zeros(5,1);
for i  = 1:m2
    p1 = y1*y_1(1,X2(i,1))*y_1(2,X2(i,2))*y_1(3,X2(i,3))*y_1(4,X2(i,4))...
        *y_1(5,X2(i,5))*y_1(6,X2(i,6))*y_1(7,X2(i,7))*y_1(8,X2(i,8));
    p2 = y2*y_2(1,X2(i,1))*y_2(2,X2(i,2))*y_2(3,X2(i,3))*y_2(4,X2(i,4))...
        *y_2(5,X2(i,5))*y_2(6,X2(i,6))*y_2(7,X2(i,7))*y_2(8,X2(i,8));
    p3 = y3*y_3(1,X2(i,1))*y_3(2,X2(i,2))*y_3(3,X2(i,3))*y_3(4,X2(i,4))...
        *y_3(5,X2(i,5))*y_3(6,X2(i,6))*y_3(7,X2(i,7))*y_3(8,X2(i,8));
    p4 = y4*y_4(1,X2(i,1))*y_4(2,X2(i,2))*y_4(3,X2(i,3))*y_4(4,X2(i,4))...
        *y_4(5,X2(i,5))*y_4(6,X2(i,6))*y_4(7,X2(i,7))*y_4(8,X2(i,8));
    p5 = y5*y_5(1,X2(i,1))*y_5(2,X2(i,2))*y_5(3,X2(i,3))*y_5(4,X2(i,4))...
        *y_5(5,X2(i,5))*y_5(6,X2(i,6))*y_5(7,X2(i,7))*y_5(8,X2(i,8));
 
ans1=(p1)/(p1+p2+p3+p4+p5);
ans2=(p2)/(p1+p2+p3+p4+p5);
ans3=(p3)/(p1+p2+p3+p4+p5);
ans4=(p4)/(p1+p2+p3+p4+p5);
ans5=(p5)/(p1+p2+p3+p4+p5);
ans = [ans1,ans2,ans3,ans4,ans5];
[Y,I] = sort(ans);
rescon(I(5)) = rescon(I(5))+1;
realcon(Y2(i)+1) = realcon(Y2(i)+1) +1;
if(Y2(i)+1 == I(5))
    right_counter = right_counter+1;
end
predict(i) = I(5)-1;
end
rescon = rescon/m2;
realcon = realcon/m2;
x = [0;1;2;3;4];
y = [rescon,realcon];
ycon(:,ooo*2-1) = rescon;
ycon(:,ooo*2) = realcon;
disp('final res:')
disp(right_counter/m2);
disp('sample number:')
disp(nums)
accurancy(ooo) = right_counter/m2;
predict = predict';
res_cons(:,ooo) = predict;
 
if nums == target
figure 
bar(x,y)
grid on
axis([-inf inf 0 1]) 
legend('predict','realclass');
xlabel('Y class');
ylabel('ratio');
title('Final result');
figure
confusion_matrix(Y2,predict)
end
 
end
 
%cont = [100,500,1000,2500,5000,7500,8000,10000];
figure('units','normalized','position',[0.02,0.02,0.95,0.95])
subplot(2,4,1);
confusion_matrix(Y2,res_cons(:,1))
title('confusion matrix with train data 100')
subplot(2,4,2);
confusion_matrix(Y2,res_cons(:,2))
title('confusion matrix with train data 500')
subplot(2,4,3);
confusion_matrix(Y2,res_cons(:,3))
title('confusion matrix with train data 1000')
subplot(2,4,4);
confusion_matrix(Y2,res_cons(:,4))
title('confusion matrix with train data 2500')
subplot(2,4,5);
confusion_matrix(Y2,res_cons(:,5))
title('confusion matrix with train data 5000')
subplot(2,4,6);
confusion_matrix(Y2,res_cons(:,6))
title('confusion matrix with train data 7500')
subplot(2,4,7);
confusion_matrix(Y2,res_cons(:,7))
title('confusion matrix with train data 8000')
subplot(2,4,8);
confusion_matrix(Y2,res_cons(:,8))
title('confusion matrix with train data 10000');
ycon;
x = [0;1;2;3;4];
title_lab= ['100','500','1000','2500','5000','7500','8000','10000'];
figure('units','normalized','position',[0.02,0.02,0.95,0.95])
for i = 1:8
subplot(2,4,i);
y = [ycon(:,2*i-1),ycon(:,2*i)];
bar(x,y)
grid on
axis([-inf inf 0 1]) 
legend('predict','realclass');
xlabel('Y class');
ylabel('ratio');
title(['Final result: ',(num2str(cont(i)))]);
end
x = 1:8;
figure
plot(x,accurancy, '--b','Linewidth',2);
grid on
%axis([-inf inf 0 1])
set(gca,'XTickLabel',{'100','500','1000','2500','5000','7500','8000','10000'});
title('The accuracy of the Naive Bayes model with different size of samples');
xlabel('Size of trainning data')
ylabel('Accuracy')

结果绘制函数

close all
clc
clear
%本次实验
%需要说明的几点:
%1.我们采用了拉普拉斯光顺所以结果会有小的波动
%2.控制采样数目不是采用随机数而是采样
%存储不同的采样情
res_counter =1;
laplace  = [3,5,4,4,3,2,3,3]
for ooo = [100:100:10000]
 
nums = ooo; %控制训练数据的个数
 
 
a = 1;
b = 10000;  %设置随机数的取值范围
train = load('training_data.txt');
test = load('test_data.txt');
%这里我自写了一个采样的函数(不重叠采样)
%使用随机数是不安全的
if nums == 10000
else
    sams = sampling(1:10000,nums);
    train = train(sams,:);
end
%每次随机抽取100个例子
[m1,n] = size(train);
%m1是训练集的数目
[m2,n] = size(test);
%m2是测试集的数目
X = train(:,1:8);
Y = train(:,9);
X2 = test(:,1:8);
Y2 = test(:,9);
X2 = X2+1;
counter1 = 0;
counter2 = 0;
counter3 = 0;
counter4 = 0;
counter5 = 0;
%共计5类样本,记录不同类别的数量
count_1 = zeros(8,5);
count_2 = zeros(8,5);
count_3 = zeros(8,5);
count_4 = zeros(8,5);
count_5 = zeros(8,5);
%统计样本中属性的情况
%第1类中第i个属性取值为j的:数目
for i = 1:m1
    if(Y(i)==0)
        counter1 = counter1+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_1(j,k) = count_1(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==1)
        counter2 = counter2+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_2(j,k) = count_2(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==2)
        counter3 = counter3+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_3(j,k) = count_3(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==3)
        counter4 = counter4+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_4(j,k) = count_4(j,k) +1;
                    break;
                end
            end  
        end
    elseif(Y(i)==4)
        counter5 = counter5+1;
        for j = 1:8
            for k = 1:5
                if X(i,j) ==(k-1)
                    count_5(j,k) = count_5(j,k) +1;
                    break;
                end
            end  
        end
    end     
end
y1 = (counter1+1)/(m1+5);
y2 = (counter2+1)/(m1+5);
y3 = (counter3+1)/(m1+5);
y4 = (counter4+1)/(m1+5);
y5 = (counter5+1)/(m1+5);
for i  = 1:8
    for j  = 1:5
        y_1(i,j) = (1+count_1(i,j))/(counter1+laplace(i));
        y_2(i,j) = (count_2(i,j)+1)/(counter2+laplace(i));
        y_3(i,j) = (count_3(i,j)+1)/(counter3+laplace(i));
        y_4(i,j) = (count_4(i,j)+1)/(counter4+laplace(i));
        y_5(i,j) = (count_5(i,j)+1)/(counter5+laplace(i));
    end
end
right_counter  = 0;
rescon = zeros(5,1);
realcon = zeros(5,1);
for i  = 1:m2
    p1 = y1*y_1(1,X2(i,1))*y_1(2,X2(i,2))*y_1(3,X2(i,3))*y_1(4,X2(i,4))...
        *y_1(5,X2(i,5))*y_1(6,X2(i,6))*y_1(7,X2(i,7))*y_1(8,X2(i,8));
    p2 = y2*y_2(1,X2(i,1))*y_2(2,X2(i,2))*y_2(3,X2(i,3))*y_2(4,X2(i,4))...
        *y_2(5,X2(i,5))*y_2(6,X2(i,6))*y_2(7,X2(i,7))*y_2(8,X2(i,8));
    p3 = y3*y_3(1,X2(i,1))*y_3(2,X2(i,2))*y_3(3,X2(i,3))*y_3(4,X2(i,4))...
        *y_3(5,X2(i,5))*y_3(6,X2(i,6))*y_3(7,X2(i,7))*y_3(8,X2(i,8));
    p4 = y4*y_4(1,X2(i,1))*y_4(2,X2(i,2))*y_4(3,X2(i,3))*y_4(4,X2(i,4))...
        *y_4(5,X2(i,5))*y_4(6,X2(i,6))*y_4(7,X2(i,7))*y_4(8,X2(i,8));
    p5 = y5*y_5(1,X2(i,1))*y_5(2,X2(i,2))*y_5(3,X2(i,3))*y_5(4,X2(i,4))...
        *y_5(5,X2(i,5))*y_5(6,X2(i,6))*y_5(7,X2(i,7))*y_5(8,X2(i,8));
 
ans1=(p1)/(p1+p2+p3+p4+p5);
ans2=(p2)/(p1+p2+p3+p4+p5);
ans3=(p3)/(p1+p2+p3+p4+p5);
ans4=(p4)/(p1+p2+p3+p4+p5);
ans5=(p5)/(p1+p2+p3+p4+p5);
ans = [ans1,ans2,ans3,ans4,ans5];
[Y,I] = sort(ans);
rescon(I(5)) = rescon(I(5))+1;
realcon(Y2(i)+1) = realcon(Y2(i)+1) +1;
if(Y2(i)+1 == I(5))
    right_counter = right_counter+1;
end
predict(i) = I(5)-1;
end
rescon = rescon/m2;
realcon = realcon/m2;
x = [0;1;2;3;4];
y = [rescon,realcon];
ycon(:,ooo*2-1) = rescon;
ycon(:,ooo*2) = realcon;
disp('final res:')
disp(right_counter/m2);
disp('sample number:')
disp(nums)
resconsp(res_counter) = right_counter/m2;
predict = predict';
res_counter = res_counter+1;
 
end
size(resconsp)
figure
x = [100:100:10000]
plot(x,resconsp, '--b','Linewidth',2);
grid on
axis([-inf inf 0 1]) 
title('The accuracy of the Naive Bayes model with different size of samples');
xlabel('Size of trainning data')
ylabel('Accuracy')

混淆矩阵

function confusion_matrix(actual,detected)
[mat,order] = confusionmat(actual,detected);
 
imagesc(mat);            %# Create a colored plot of the matrix values
colormap(flipud(gray));  %# Change the colormap to gray (so higher values are
                         %#   black and lower values are white)
                         
textStrings = num2str(mat(:),'%0.02f');  %# Create strings from the matrix values
textStrings = strtrim(cellstr(textStrings));  %# Remove any space padding
 
[x,y] = meshgrid(1:5);   %# Create x and y coordinates for the strings
hStrings = text(x(:),y(:),textStrings(:),...      %# Plot the strings
                'HorizontalAlignment','center');
midValue = mean(get(gca,'CLim'));  %# Get the middle value of the color range
textColors = repmat(mat(:) > midValue,1,3);  %# Choose white or black for the
                                             %#   text color of the strings so
                                             %#   they can be easily seen over
                                             %#   the background color
set(hStrings,{'Color'},num2cell(textColors,2));  %# Change the text colors
 
set(gca,'XTick',1:5,...                         %# Change the axes tick marks
        'XTickLabel',{'0','1','2','3','4'},...  %#   and tick labels
        'YTick',1:5,...
        'YTickLabel',{'0','1','2','3','4'},...
        'TickLength',[0 0]);
xlabel('Real Class');
ylabel('Predict Class');

采样函数

function s = sampling(R, n)
% 选择抽样,R为记录集合,n为抽取的样本数
% 编写函数时用的测试数据
N = length(R);
t = 0;   % 处理过的记录总数
m = 0;   % 已选得的记录数
while 1
    U  = rand;
    if (N-t)*U < n-m
        m = m + 1;
        s(m) = R(t+1);
        % 若已抽取到足够的记录,则算法终止
        if m >= n, break, end
    end
    t = t + 1;
end

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值