朴素贝叶斯实现文档分类
网上与朴素贝叶斯相关的内容很多,本文仅作为作业的记录博客,重点记录在实现作业过程中遇到的问题和解决方法。
作业要求:
- 实验数据在bayes_datasets文件夹中。其中,
train为训练数据集,包含hotel和travel两个中文文本集,文本为txt格式。hotel文本集中全部都是介绍酒店信息的文档,travel文本集中全部都是介绍景点信息的文档;
Bayes_datasets/test为测试数据集,包含若干hotel类文档和travel类文档。 - 用朴素贝叶斯算法对上述两类文档进行分类。要求输出测试数据集的文档分类结果,即每类文档的数量。
(例:hotel:XX,travel:XX)
贝叶斯公式:
朴素贝叶斯算法的核心,贝叶斯公式如下:
换个表达形式:
代码实现:
第一部分:读取数据
f_path = os.path.abspath('.')+'/bayes_datasets/train/hotel'
f1_path = os.path.abspath('.')+'/bayes_datasets/train/travel'
f2_path = os.path.abspath('.')+'/bayes_datasets/test'
ls = os.listdir(f_path)
ls1 = os.listdir(f1_path)
ls2 = os.listdir(f2_path)
#去掉网址的正则表达式
pattern = r"(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*,]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)|([a-zA-Z]+.\w+\.+[a-zA-Z0-9\/_]+)"
res = []
for i in ls:
with open(str(f_path+'\\'+i),encoding='UTF-8') as f:
lines = f.readlines()
tmp = ''.join(str(i.replace('\n','')) for i in lines)
tmp = re.sub(pattern,'',tmp)
remove_digits = str.maketrans('',