Project 4: A Large-Scale Article Classification Model Based on Naive Bayes

This article presents two text-classification case studies based on Naive Bayes. The first is a three-class task over 3,066 pre-segmented news reports; a dictionary maps words to ids to generate the training and test data. The second is a seven-class task with 6,300 training articles and 700 test articles covering finance, technology, and five other categories; the raw articles are first stripped of stop words and segmented, then mapped to ids to build the training and test sets. Finally, a Naive Bayes model is trained and tested.

1. Case 1: Naive Bayes from Scratch for Three-Class Classification

1.1 Data Overview

  • The data directory holds the corpus:
    [root@master nb_test]# cd data
    [root@master data]# ls -ls | wc -l
    3067
    
    (Note that `ls -ls` prints a leading "total" line, so 3067 output lines correspond to 3,066 files, matching the "3066 files loaded!" message below.)
    Each article is a pre-segmented (whitespace-tokenized) news report.
    Each article's filename encodes its label: business (finance), auto, or sport.

1.2 代码介绍

DataConvert.py: maps the words of the 3,066 articles to integer ids via a dictionary and writes the id-encoded corpus out, randomly routing about 80% of the articles to nb_data.train and the rest to nb_data.test.

import sys
import os
import random

WordList = []
WordIDDic = {}
TrainingPercent = 0.8

inpath = sys.argv[1]
OutFileName = sys.argv[2]
trainOutFile = open(OutFileName+".train", "w")
testOutFile = open(OutFileName+".test", "w")

def ConvertData():
    i = 0
    tag = 0
    # iterate over every file in the input directory
    for filename in os.listdir(inpath):
        if filename.find("business") != -1:
            tag = 1
        elif filename.find("auto") != -1:
            tag = 2
        elif filename.find("sport") != -1:
            tag = 3
        i += 1
        # randomly route about 80% of the articles to the training file
        rd = random.random()
        outfile = testOutFile
        if rd < TrainingPercent:
            outfile = trainOutFile

        if i % 100 == 0:
            print(i, "files processed!\r")

        infile = open(inpath+'/'+filename, 'rb')
        outfile.write(str(tag)+" ")
        content = infile.read().strip()
        content = content.decode("utf-8", 'ignore')
        words = content.replace('\n', ' ').split(' ')

        # build the vocabulary and write each word's id to the output file
        for word in words:
            if len(word.strip()) < 1:
                continue
            if word not in WordIDDic:
                WordList.append(word)
                WordIDDic[word] = len(WordList)
            outfile.write(str(WordIDDic[word])+" ")
            
        outfile.write("#"+filename+"\n")
        infile.close()

    print(i, "files loaded!")
    print(len(WordList), "unique words found!")

ConvertData()
trainOutFile.close()
testOutFile.close()
[root@master nb_test]# python DataConvert.py data/ nb_data
100 files processed!
200 files processed!
...
3000 files processed!
3066 files loaded!
41740 unique words found!

The generated nb_data.train and nb_data.test files have the following format (the original screenshots are omitted):
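Judging from the conversion code, each line starts with the class tag (1 = business, 2 = auto, 3 = sport), followed by the word ids of the article, and ends with a "#" marker plus the source filename. An illustrative line (the ids and filename below are made up for illustration):

    1 12 7 903 7 45 1024 #business.some_article.txt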
NB.py: training and testing of the Naive Bayes model.
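
Before reading the code, here is the model it implements. For a test document with word ids $x_1, \dots, x_n$, the script predicts

$$\hat{y} = \arg\max_{y}\Big[\log p(y) + \sum_{i=1}^{n} \log p(x_i \mid y)\Big],$$

where the prior $p(y)$ is the fraction of training articles in class $y$, and the smoothed likelihood of a word $w$ seen in class $y$ is

$$p(w \mid y) = \frac{\mathrm{count}(w, y) + \mathrm{DefaultFreq}}{\sum_{w'} \mathrm{count}(w', y) + 1}, \qquad \mathrm{DefaultFreq} = 0.1,$$

while a word never seen in class $y$ falls back to $\mathrm{DefaultFreq} / \big(\sum_{w'} \mathrm{count}(w', y) + 1\big)$, stored as ClassDefaultProb.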

import sys
import os
import math

DefaultFreq = 0.1  # smoothing constant used in ComputeModel
TrainingDataFile = "nb_data.train"
ModelFile = "nb_data.model"
TestDataFile = "nb_data.test"
TestOutFile = "nb_data.out"

ClassFeaDic = {}       # per-class count of each token
ClassFreq = {}         # number of articles in each class
WordDic = {}           # vocabulary of all token ids seen in training
ClassFeaProb = {}      # per-class probability of each token
ClassDefaultProb = {}  # per-class probability of an unseen token
ClassProb = {}         # class priors p(y)


def Dedup(items):
    tempDic = {}
    for item in items:
        if item not in tempDic:
            tempDic[item] = True
    return tempDic.keys()


def LoadData():
    """
    Load the training data and accumulate per-class token counts.
    :return:
    """
    i = 0
    infile = open(TrainingDataFile, 'r')
    sline = infile.readline().strip()
    while len(sline) > 0:
        # strip the trailing #filename marker
        pos = sline.find("#")
        if pos > 0:
            sline = sline[:pos].strip()
        words = sline.split(' ')
        if len(words) < 1:
            print("Format error!")
            break
        # class id: the first column of each sample
        classid = int(words[0])
        if classid not in ClassFeaDic:
            # per-class token counts
            ClassFeaDic[classid] = {}
            # per-class token probabilities
            ClassFeaProb[classid] = {}
            # number of articles in the class
            ClassFreq[classid] = 0
        ClassFreq[classid] += 1
        # the remaining columns are the article's token ids
        words = words[1:]
        # remove duplicate words, binary distribution
        # words = Dedup(words)
        for word in words:
            if len(word) < 1:
                continue
            wid = int(word)
            if wid not in WordDic:
                WordDic[wid] = 1
            # count this token for its class, used for p(x|y)
            if wid not in ClassFeaDic[classid]:
                ClassFeaDic[classid][wid] = 1
            else:
                ClassFeaDic[classid][wid] += 1
        i += 1
        sline = infile.readline().strip()
    infile.close()
    print(i, "instances loaded!")
    print(len(ClassFreq), "classes!", len(WordDic), "words!")


def ComputeModel():
    """
    Convert the raw counts into probabilities.
    :return:
    """
    sum = 0.0    # total number of training articles
    for freq in ClassFreq.values():
        sum += freq
    # compute the priors p(yi)
    for classid in ClassFreq.keys():
        # p(yi): articles in class yi / total articles
        ClassProb[classid] = (float)(ClassFreq[classid]) / (float)(sum)
    # p(xj|yi)
    # for each class, turn the counts in ClassFeaDic into the
    # probabilities stored in ClassFeaProb
    for classid in ClassFeaDic.keys():
        # total number of token occurrences in this class
        sum = 0.0
        for wid in ClassFeaDic[classid].keys():
            sum += ClassFeaDic[classid][wid]
        # newsum = (float)(sum+len(WordDic)*DefaultFreq)
        newsum = (float)(sum + 1)
        # Binary Distribution
        # newsum = (float)(ClassFreq[classid]+2*DefaultFreq)
        for wid in ClassFeaDic[classid].keys():
            # adding DefaultFreq smooths the estimate so no probability is 0
            ClassFeaProb[classid][wid] = (float)(ClassFeaDic[classid][wid] + DefaultFreq) / newsum
        ClassDefaultProb[classid] = (float)(DefaultFreq) / newsum
    return


def SaveModel():
    """
    Save the model: write the probabilities to disk.
    :return:
    """
    outfile = open(ModelFile, 'w')
    for classid in ClassFreq.keys():
        outfile.write(str(classid))
        outfile.write(' ')
        outfile.write(str(ClassProb[classid]))
        outfile.write(' ')
        outfile.write(str(ClassDefaultProb[classid]))
        outfile.write(' ')
    outfile.write('\n')
    for classid in ClassFeaDic.keys():
        for wid in ClassFeaDic[classid].keys():
            outfile.write(str(wid) + ' ' + str(ClassFeaProb[classid][wid]))
            outfile.write(' ')
        outfile.write('\n')
    outfile.close()


def LoadModel():
    """
    Load the model: restore the four dictionaries from the model file for prediction.
    :return:
    """
    global WordDic
    WordDic = {}
    global ClassFeaProb
    ClassFeaProb = {}
    global ClassDefaultProb
    ClassDefaultProb = {}
    global ClassProb
    ClassProb = {}
    infile = open(ModelFile, 'r')
    sline = infile.readline().strip()
    items = sline.split(' ')
    if len(items) < 6:
        print("Model format error!")
        return
    i = 0
    while i < len(items):
        classid = int(items[i])
        ClassFeaProb[classid] = {}
        i += 1
        if i >= len(items):
            print("Model format error!")
            return
        ClassProb[classid] = float(items[i])
        i += 1
        if i >= len(items):
            print("Model format error!")
            return
        ClassDefaultProb[classid] = float(items[i])
        i += 1
    for classid in ClassProb.keys():
        sline = infile.readline().strip()
        items = sline.split(' ')
        i = 0
        while i < len(items):
            wid = int(items[i])
            if wid not in WordDic:
                WordDic[wid] = 1
            i += 1
            if i >= len(items):
                print("Model format error!")
                return
            ClassFeaProb[classid][wid] = float(items[i])
            i += 1
    infile.close()
    print(len(ClassProb), "classes!", len(WordDic), "words!")


def Predict():
    """
    Predict a label for every test document.
    :return:
    """
    global WordDic
    global ClassFeaProb
    global ClassDefaultProb
    global ClassProb

    TrueLabelList = []
    PredLabelList = []
    i = 0
    infile = open(TestDataFile, 'r')
    outfile = open(TestOutFile, 'w')
    sline = infile.readline().strip()
    # per-class scores for the current document:
    # p(yi|X) ∝ p(yi) * p(X|yi)
    # p(X|yi) = p(x1|yi) * ... * p(xn|yi)
    scoreDic = {}
    iline = 0
    while len(sline) > 0:
        iline += 1
        if iline % 10 == 0:
            print(iline, " lines finished!\r")
        pos = sline.find("#")
        if pos > 0:
            sline = sline[:pos].strip()
        words = sline.split(' ')
        if len(words) < 1:
            print("Format error!")
            break
        classid = int(words[0])
        # true label
        TrueLabelList.append(classid)
        words = words[1:]
        # remove duplicate words, binary distribution
        # words = Dedup(words)
        # initialize each class score with the log prior log p(yi)
        for classid in ClassProb.keys():
            scoreDic[classid] = math.log(ClassProb[classid])
        # accumulate log p(xi|yi) over the document's words
        for word in words:
            if len(word) < 1:
                continue
            wid = int(word)
            if wid not in WordDic:
                # print "OOV word:",wid
                continue
            # add this word's log-likelihood for every class
            for classid in ClassProb.keys():
                if wid not in ClassFeaProb[classid]:
                    scoreDic[classid] += math.log(ClassDefaultProb[classid])
                else:
                    scoreDic[classid] += math.log(ClassFeaProb[classid][wid])
        # binary distribution
        # wid = 1
        # while wid < len(WordDic)+1:
        #   if str(wid) in words:
        #       wid += 1
        #       continue
        #   for classid in ClassProb.keys():
        #       if wid not in ClassFeaProb[classid]:
        #           scoreDic[classid] += math.log(1-ClassDefaultProb[classid])
        #       else:
        #           scoreDic[classid] += math.log(1-ClassFeaProb[classid][wid])
        #   wid += 1
        # the class with the highest score becomes the predicted label
        i += 1
        maxProb = max(scoreDic.values())
        for classid in scoreDic.keys():
            if scoreDic[classid] == maxProb:
                PredLabelList.append(classid)
                break  # keep a single prediction per document in case of ties
        sline = infile.readline().strip()
    infile.close()
    outfile.close()
    print(len(PredLabelList), len(TrueLabelList))
    return TrueLabelList, PredLabelList


def Evaluate(TrueList, PredList):
    """
    Compute overall accuracy.
    :param TrueList:
    :param PredList:
    :return:
    """
    accuracy = 0
    i = 0
    while i < len(TrueList):
        if TrueList[i] == PredList[i]:
            accuracy += 1
        i += 1
    # accuracy = correct predictions / total predictions
    accuracy = (float)(accuracy) / (float)(len(TrueList))
    print("Accuracy:", accuracy)


def CalPreRec(TrueList, PredList, classid):
    """
    Compute precision and recall for a given class.
    :param TrueList:
    :param PredList:
    :param classid:
    :return:
    """
    correctNum = 0
    allNum = 0
    predNum = 0
    i = 0
    while i < len(TrueList):
        if TrueList[i] == classid:
            allNum += 1
            if PredList[i] == TrueList[i]:
                correctNum += 1
        if PredList[i] == classid:
            predNum += 1
        i += 1
    return (float)(correctNum) / (float)(predNum), (float)(correctNum) / (float)(allNum)


# main framework
if len(sys.argv) < 4:
    print("Usage incorrect!")
# argv[1] == '1' selects training
elif sys.argv[1] == '1':
    print("start training:")
    TrainingDataFile = sys.argv[2]
    ModelFile = sys.argv[3]
    LoadData()
    ComputeModel()
    SaveModel()
# argv[1] == '0' selects testing
elif sys.argv[1] == '0':
    print("start testing:")
    TestDataFile = sys.argv[2]
    ModelFile = sys.argv[3]
    TestOutFile = sys.argv[4]

    LoadModel()
    TList, PList = Predict()
    i = 0
    outfile = open(TestOutFile, 'w')
    while i < len(TList):
        outfile.write(str(TList[i]))
        outfile.write(' ')
        outfile.write(str(PList[i]))
        outfile.write('\n')
        i += 1
    outfile.close()
    Evaluate(TList, PList)
    # precision and recall for every class
    for classid in ClassProb.keys():
        pre, rec = CalPreRec(TList, PList, classid)
        print("Precision and recall for Class", classid, ":", pre, rec)
else:
    print("Usage incorrect!")
[root@master nb_test]# python NB.py 1 nb_data.train model
start training:
2408 instances loaded!
3 classes! 37918 words!
[root@master nb_test]# python NB.py 0 nb_data.test model out
start testing:
3 classes! 37918 words!
10  lines finished!
20  lines finished!
...
650  lines finished!
658 658
Accuracy: 0.9741641337386018
Precision and recall for Class 1 : 0.9590909090909091 0.9768518518518519
Precision and recall for Class 2 : 0.9771689497716894 0.9511111111111111
Precision and recall for Class 3 : 0.9863013698630136 0.9953917050691244
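
Since the test run writes one "true_label pred_label" pair per line to the out file (see the main block of NB.py), further metrics can be derived offline. A minimal sketch, assuming the out file produced by the run above, that computes per-class F1 from those pairs:

# f1_from_out.py -- a minimal sketch (not part of the original project),
# assuming the "out" file written by NB.py in testing mode:
# one "true_label pred_label" pair per line.
from collections import Counter

tp = Counter()  # true positives per class
fp = Counter()  # false positives per class
fn = Counter()  # false negatives per class
with open("out") as f:
    for line in f:
        t, p = map(int, line.split())
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but the truth was t
            fn[t] += 1  # an article of class t was missed

for c in sorted(set(tp) | set(fp) | set(fn)):
    prec = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] > 0 else 0.0
    rec = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] > 0 else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
    print("Class", c, "F1:", round(f1, 4))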

2. Case 2: Naive Bayes from Scratch for Seven-Class Classification

2.1 Data Overview

Data:
Training set: 6,300 articles
Test set: 700 articles
Labels: 1 finance; 2 technology; 3 auto; 4 real estate; 5 sports; 6 entertainment; 7 other

Each article is raw, unsegmented text (the original screenshot is omitted),
so the articles must first go through stop-word removal and word segmentation.

2.2 Code Overview

fenci.py: segments each article into words and removes stop words.

import sys
import os
import jieba

inpath = sys.argv[1]
outpath = sys.argv[2]
stopWordpath = open('data/stop.txt', 'rb')


def fenci():
    # load the stop-word list (one word per line)
    stopword = [line.strip().decode('utf-8') for line in stopWordpath.readlines()]
    # iterate over every file in the input directory
    for filename in os.listdir(inpath):
        inputfile = open(inpath + '/' + filename, 'rb')
        outputfile = open(outpath + '/' + filename, 'w', encoding='utf-8')
        content = inputfile.read().strip()
        content = content.decode('utf-8', 'ignore')

        # precise-mode segmentation with jieba
        seg_lists = jieba.lcut(content, cut_all=False)

        # keep only non-stop-words, separated by spaces
        for seg in seg_lists:
            if seg not in stopword:
                outputfile.write(seg)
                outputfile.write(' ')

        # close each file pair before moving on to the next article
        inputfile.close()
        outputfile.close()


if __name__ == '__main__':
    fenci()

After segmentation, each article is a sequence of space-separated tokens (the original screenshot is omitted).
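A hedged usage sketch (the directory names raw_data/ and seg_data/ are assumptions; the original post does not show the command):

    python fenci.py raw_data/ seg_data/

Note that fenci.py expects the stop-word list at data/stop.txt relative to the working directory.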
DataConvert.py: maps the segmented articles to ids and gathers them into a single file per split (one article per line).

import sys
import os

WordList = []
WordIDDic = {}

inTrainpath = sys.argv[1]
inTestpath = sys.argv[2]
OutFileName = sys.argv[3]
trainOutFile = open(OutFileName + ".train", "w")
testOutFile = open(OutFileName + ".test", "w")


def ConvertData():
    i = 0
    tag = 0
    for filename in os.listdir(inTrainpath):
        # the first character of the filename is the class label (1-7)
        if filename[0].isdigit():
            tag = int(filename[0])
        i += 1
        outfile = trainOutFile

        if i % 100 == 0:
            print(i, "files processed!\r")

        infile = open(inTrainpath + '/' + filename, 'rb')
        outfile.write(str(tag) + " ")
        content = infile.read().strip()
        content = content.decode("utf-8", 'ignore')
        words = content.replace('\n', ' ').split(' ')
        for word in words:
            if len(word.strip()) < 1:
                continue
            if word not in WordIDDic:
                WordList.append(word)
                WordIDDic[word] = len(WordList)
            outfile.write(str(WordIDDic[word])+" ")
        outfile.write("#"+filename+"\n")
        infile.close()


    for filename in os.listdir(inTestpath):
        # the first character of the filename is the class label (1-7)
        if filename[0].isdigit():
            tag = int(filename[0])
        i += 1
        outfile = testOutFile

        if i % 100 == 0:
            print(i, "files processed!\r")

        infile = open(inTestpath + '/' + filename, 'rb')
        outfile.write(str(tag) + " ")
        content = infile.read().strip()
        content = content.decode("utf-8", 'ignore')
        words = content.replace('\n', ' ').split(' ')
        for word in words:
            if len(word.strip()) < 1:
                continue
            if word not in WordIDDic:
                WordList.append(word)
                WordIDDic[word] = len(WordList)
            outfile.write(str(WordIDDic[word])+" ")
        outfile.write("#"+filename+"\n")
        infile.close()

    print(i, "files loaded!")
    print(len(WordList), "unique words found!")

ConvertData()
trainOutFile.close()
testOutFile.close()
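
A hedged usage sketch (the directory names are assumptions; the original post does not show the command). The first two arguments are the segmented training and test directories, the third is the output file prefix:

    python DataConvert.py seg_train/ seg_test/ nb_data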


This generates nb_data.train and nb_data.test in the same id-encoded format as in Case 1 (the original screenshot is omitted).
NB.py: training and testing of the Naive Bayes model. The script is identical to the one listed in section 1.2, so it is not repeated here. Its usage is:

# Usage:
# Training: NB.py 1 TrainingDataFile ModelFile
# Testing: NB.py 0 TestDataFile ModelFile OutFile

The program output is:

(tf2) D:\PycharmProjects\数据挖掘与推荐系统\机器学习与深度学习基础\nb_homework>python NB.py 1 nb_data.train model
start training:
6300 instances loaded!
7 classes! 120865 words!

(tf2) D:\PycharmProjects\数据挖掘与推荐系统\机器学习与深度学习基础\nb_homework>python NB.py 0 nb_data.test model out
start testing:
7 classes! 120865 words!
10  lines finished!
20  lines finished!
...
700  lines finished!
700 700
Accuracy: 0.8657142857142858
Precision and recall for Class 1 : 0.8021978021978022 0.73
Precision and recall for Class 2 : 0.831858407079646 0.94
Precision and recall for Class 3 : 0.9595959595959596 0.95
Precision and recall for Class 4 : 0.9318181818181818 0.82
Precision and recall for Class 5 : 1.0 0.95
Precision and recall for Class 6 : 0.868421052631579 0.99
Precision and recall for Class 7 : 0.68 0.68
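
As a sanity check: the per-class recalls are all exact hundredths, which suggests the 700 test articles are split evenly, 100 per class. Under that assumption, the overall accuracy is the mean of the per-class recalls: (0.73 + 0.94 + 0.95 + 0.82 + 0.95 + 0.99 + 0.68) / 7 = 606/700 ≈ 0.8657, matching the reported accuracy.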