Python数据科学:朴素贝叶斯分类

4.4朴素贝叶斯分类

朴素贝叶斯分类是非常简单快速的分类算法,适用于高维度数据集。

基础数学原理:贝叶斯定理

一个统计量条件概率公式,判断样本属于某个标签的概率。

P(L|f)表示在f特征下,样本属于某类标签的概率。

确定标签数据的随机过程,通常假设数据符合某种分布,确定使用的模型。

假设每个标签的数据符合高斯分布,那么使用高斯朴素贝叶斯模型

实例:高斯朴素贝叶斯

def skLearn10():

    '''

    高斯朴素贝叶斯

    :return:

    '''

    from sklearn.datasets import make_blobs

    #模拟特征矩阵,目标值

    X,y = make_blobs(100,2,centers=2,random_state=3,cluster_std=1.4)

    #(100,2)

    print(X.shape)

    plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap='RdBu')

    #显示图片

    #plt.show()



    #使用高斯朴素贝叶斯模型

    from sklearn.naive_bayes import GaussianNB

    model = GaussianNB()

    #训练数据

    model.fit(X,y)



    #生成测试数据

    rng = np.random.RandomState(0)

    Xnew = [12,14] * rng.rand(2000,2)

    #判断样本属于某个标签概率

    ytest = model.predict_proba(Xnew)

    print(ytest.round(2))

实例2:多项式朴素贝叶斯

def skLearn11():

    '''

    多项式朴素贝叶斯

    :return:

    '''

    #获取数据,访问不了。

    from sklearn.datasets import  fetch_20newsgroups

    data = fetch_20newsgroups()

    #选择主题

    subjects = ['alt.atheism','comp.graphics','comp.windows.x','rec.autos']

    #获取训练或测试数据集

    train = fetch_20newsgroups(subset='train',categories=subjects)

    test = fetch_20newsgroups(subset='test',categories=subjects)

    #选择模型

    from sklearn.feature_extraction.text import TfidfVectorizer

    from sklearn.naive_bayes import MultinomialNB

    from sklearn.pipeline import make_pipeline

    #使用pipline将多个操作组合

    model = make_pipeline(TfidfVectorizer(),MultinomialNB())

    #拟合数据

    model.fit(train.data,train.target)

    #预测数据

    labels = model.predict(test.data)



    #使用混淆矩阵

    from sklearn.metrics import confusion_matrix

    mat = confusion_matrix(test.target,labels)

    #显示预测效果

    sns.heatmap(mat.T,square=True,annot=True,fmt='d',cbar=False,\

                xticklabels=train.target_names,\

                yticklabels=train.target_names)

    plt.xlabel('true  label')

    plt.ylabel('predict label')

实例3:

数据集地址:

http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/news20.html

使用本地news数据,进行多项式朴素贝叶斯文本类型news分析

import os

def get_path_list(type,subjects):

    '''

    获取文件列表

    :param type:

    :param subjects:

    :return:

    '''

    fold_path = ''

    path_list = []

    type_list = []

    for subject in subjects:

        if type == 'test':

            #将news文件放在data文件夹下

            fold_path = './data/20news-bydate-test/' + subject

        elif type == 'train':

            fold_path =  './data/20news-bydate-train/' + subject

        for dir in os.listdir(fold_path):

            path_list.append(fold_path + '/'+dir)

            type_list.append(subject)

    return  path_list,type_list



def get_data_list(path_list):

    '''

    获取数据

    :param path_list:

    :return:

    '''

    fileContent = []

    for path in path_list:

        with open(path, 'rt',errors='ignore') as file_obj:

            file_str = file_obj.read()

            fileContent.append(file_str)

    return fileContent



def skLearn12():

    '''

    多项式朴素贝叶斯

    :return:

    '''

    #选择主题

    subjects = ['alt.atheism', 'comp.graphics', 'comp.windows.x', 'rec.autos']



    # 获取train内容

    #获取路径

    path_list,type_list = get_path_list('train',subjects)

    files_list = get_data_list(path_list)

    #构建数据框体

    df_news = pd.DataFrame({'data':files_list,'target':type_list})

    #获取test内容

    path_test,type_test = get_path_list('test',subjects)

    files_test = get_data_list(path_test)

    df_test = pd.DataFrame({'data': files_test, 'target': type_test})





    #选择模型

    from sklearn.feature_extraction.text import TfidfVectorizer

    from sklearn.naive_bayes import MultinomialNB

    from sklearn.pipeline import make_pipeline

    #使用pipline将多个操作组合

    model = make_pipeline(TfidfVectorizer(),MultinomialNB())

    #拟合数据

    model.fit(df_news.data,df_news.target)

    #预测数据

    labels = model.predict(df_test.data)



    #获取预测效果

    #使用混淆矩阵

    from sklearn.metrics import confusion_matrix

    mat = confusion_matrix(df_test.target,labels)

    #显示预测效果

    fig,ax = plt.subplots(figsize=(8,8))

    sns.heatmap(mat,square=True,annot=True,fmt='d',cbar=False,\

                xticklabels=subjects,\

                yticklabels=subjects

                )

    plt.xlabel('true  label')

    plt.ylabel('predict label')

    #设置旋转

    ax.set_xticklabels(ax.get_xticklabels(), rotation=360)

    #显示图片

    plt.show()

  • 7
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

偶是不器

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值