The task of this lab is text classification, using the 20news_bydate dataset.
The experiment proceeds as follows:
1. Algorithm Introduction
· Naive Bayes is one of the most widely used classification algorithms.
· Given training data X with class labels Y, Bayes' theorem states:
P(Y|X) = P(X|Y) × P(Y) / P(X)
where P(Y|X) is the posterior probability, i.e. the probability that a test document X belongs to class Y. For a document Xi, we compute this posterior for every class, obtaining P(Y1|Xi), P(Y2|Xi), ..., P(Yn|Xi); the class with the largest value is the predicted class.
In the formula above:
· P(Y) is the prior probability of each class, computed here as:
total number of words in the class / total number of words in the whole training set.
· P(X | Y) = P(x1, x2, ..., xn | Y), where the xi are the features of the document.
Naive Bayes "naively" treats the features as mutually independent, so:
P(X | Y) = P(x1 | Y) × P(x2 | Y) × P(x3 | Y) × ... × P(xn | Y).
Here the features xi are simply all the words obtained after preprocessing and tokenizing the document.
Each factor is computed as follows (taking P(xi | Yj) as an example):
total occurrences of word xi in all training documents of class Yj / total number of words contained in all training documents of class Yj.
· P(X) is the same for every class, so it can be ignored.
Putting this together, the main steps of the experiment are:
obtain the dataset -> preprocess and tokenize to build the corpus and bag of words -> iterate to compute P(Y) and P(X|Y) ->
the class maximizing P(Y) × P(X|Y) is the predicted class (a toy sketch of this decision rule follows).
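To make the decision rule concrete, here is a minimal, self-contained toy sketch; the two classes and word counts are invented for illustration and are not taken from the real dataset:

# Toy Naive Bayes scoring; the word counts below are made up for illustration.
# class_word_counts[c][w] = occurrences of word w in training docs of class c
class_word_counts = {
    'sport': {'ball': 10, 'team': 5, 'win': 5},
    'tech':  {'code': 12, 'ball': 1, 'cpu': 7},
}

def toy_predict(doc_words):
    total = sum(sum(c.values()) for c in class_word_counts.values())
    best_class, best_score = None, -1.0
    for label, counts in class_word_counts.items():
        word_sum = sum(counts.values())
        score = word_sum / total  # prior P(Y): the class's share of all words
        for w in doc_words:
            # add-one (Laplace) smoothing, mirroring the implementation below
            score *= (counts.get(w, 0) + 1) / (word_sum + len(class_word_counts))
        if score > best_score:
            best_class, best_score = label, score
    return best_class

print(toy_predict(['ball', 'team']))  # -> 'sport'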
2. Obtaining the Dataset
The 20newsgroups dataset can be imported directly from the sklearn module:
from sklearn.datasets import fetch_20newsgroups  # import the loader
news_data = fetch_20newsgroups(subset="all")  # fetch the data
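fetch_20newsgroups can also load the official train/test split directly; it returns a Bunch object holding the raw texts and labels:

from sklearn.datasets import fetch_20newsgroups

train = fetch_20newsgroups(subset="train")  # or subset="test" / subset="all"
print(len(train.data))         # raw articles as strings
print(train.target[:5])        # integer class labels (0-19)
print(train.target_names[:3])  # the newsgroup names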
It can also be downloaded manually:
Link: https://pan.baidu.com/s/1YO-Je1lT_y-MbGRpSwHQSA
Extraction code: qwer
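After extracting the manual download, the preprocessing code in the next section expects roughly the following layout (one subdirectory per newsgroup; the category names shown are examples):

20news-bydate/
    20news-bydate-train/
        alt.atheism/
        comp.graphics/
        ...
    20news-bydate-test/
        alt.atheism/
        ...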
3. Data Preprocessing
import os
# string and nltk are used for text preprocessing
import string
import nltk
from nltk.corpus import stopwords
import pickle
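# One-time environment setup (an assumption about the local machine): the NLTK
# tokenizer models and stopword list used below must be downloaded once, e.g.:
# nltk.download('punkt')
# nltk.download('stopwords')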
class Textprocess():
    def __init__(self):
        # path of the raw corpus
        self.corpus = ''
        # path of the tokenized output
        self.segment = ''
        # files that store the tokenized, deduplicated results
        self.word_list = ''
        self.label_list = ''
        # file that stores the raw training tokens; the related code is
        # commented out when processing the test set
        self.ori_words_list = ''
    # Create the train_segment and test_segment folders under the corpus path
    # and store the preprocessed, tokenized results there
    def preprocess(self):
        mydir = os.listdir(self.corpus)
        for dir in mydir:
            create_dir = self.corpus + '/' + dir + '_segment'
            os.makedirs(create_dir)
            dir_path = self.corpus + '/' + dir
            news_list = os.listdir(dir_path)
            # one directory per newsgroup category
            for news in news_list:
                path = create_dir + '/' + news
                os.makedirs(path)
                news_path = dir_path + '/' + news
                files = os.listdir(news_path)
                # one plain-text file per article
                for file in files:
                    file_path = news_path + '/' + file
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f1:
                        content = f1.read()
                    clean_content = self.data_clean(content)
                    new_file_path = path + '/' + file
                    with open(new_file_path, 'w', encoding='utf-8', errors='ignore') as f2:
                        f2.write(clean_content)
    def data_clean(self, data):
        # convert to lower case
        data1 = data.lower()
        # strip punctuation
        remove = str.maketrans('', '', string.punctuation)
        data2 = data1.translate(remove)
        # tokenize
        data3 = nltk.word_tokenize(data2)
        # drop stopwords, non-alphabetic tokens and overly long tokens;
        # a set makes the per-token stopword lookup much faster than a list
        stop_words = set(stopwords.words('english'))
        data4 = [w for w in data3 if (w not in stop_words) and w.isalpha() and (len(w) < 15)]
        data_str = ' '.join(data4)
        return data_str
    def create_non_repeated_words(self):
        self.content_list = []
        self.labels_list = []
        # self.ori_list = []
        mydir = sorted(os.listdir(self.segment))
        label = 0
        for dir in mydir:
            dir_path = self.segment + '/' + dir
            files = sorted(os.listdir(dir_path))
            for file in files:
                file_path = dir_path + '/' + file
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    line = f.read()
                line1 = line.strip('\n')
                line2 = line1.split()
                # keep the raw token list; comment out when processing the test set
                # self.ori_list.append(line2)
                # deduplicate the list and drop words that occur only once
                line3 = []
                once_word = []
                for i in line2:
                    if i not in once_word:
                        once_word.append(i)
                    else:
                        if i not in line3:
                            line3.append(i)
                self.content_list.append(line3)
                self.labels_list.append(label)
            label += 1
        self.data_dump(self.word_list, self.content_list)
        self.data_dump(self.label_list, self.labels_list)
        # self.data_dump(self.ori_words_list, self.ori_list)
    def data_dump(self, path, data):
        # pickle data to disk
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def data_load(self, path):
        # load previously pickled data
        with open(path, 'rb') as f:
            data = pickle.load(f)
        return data
# process the training set; uncomment the self.ori_list lines above so that
# 'original_bag' (the raw training tokens) is actually written
text = Textprocess()
text.corpus = r'.\20news-bydate'
text.segment = r'.\20news-bydate\20news-bydate-train_segment'
text.word_list = 'train_words'
text.label_list = 'train_labels'
text.ori_words_list = 'original_bag'
text.preprocess()
text.create_non_repeated_words()

# process the test set (preprocess() above already segmented both subsets)
test = Textprocess()
test.corpus = r'.\20news-bydate'
test.segment = r'.\20news-bydate\20news-bydate-test_segment'
test.word_list = 'test_words'
test.label_list = 'test_labels'
test.create_non_repeated_words()
4. Implementing the Naive Bayes Algorithm
# encoding=utf-8
import textprocess_detail as tp  # the preprocessing script from section 3
import numpy as np
from sklearn import metrics

ori_words = tp.Textprocess().data_load('original_bag')
train_labels = tp.Textprocess().data_load('train_labels')
test_words = tp.Textprocess().data_load('test_words')
test_labels = tp.Textprocess().data_load('test_labels')
# total number of words contained in each of the 20 classes
def words_sum():
    totals = [0 for i in range(20)]  # avoid shadowing the built-in sum()
    for i in range(len(ori_words)):
        count = len(ori_words[i])
        totals[train_labels[i]] += count
    return totals
# prior probability of each class: class word count / total word count
def category_probability(counts):
    total = 0
    cp = []
    for i in counts:
        total += i
    for j in counts:
        cp.append(j / total)
    return cp
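# Example: per-class word totals of [100, 300] yield priors [0.25, 0.75].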
# score p(x1|y) * p(x2|y) * ... * p(y) for every test document and class
def predict(totals, cp):
    prob_matrix = []
    for doc in range(len(test_words)):
        p_list = []
        for predict_label in range(20):
            p = 1
            word_sum = totals[predict_label]
            for word in test_words[doc]:
                count = 0
                # count the word's occurrences in this class's training documents
                for i in range(len(ori_words)):
                    if train_labels[i] == predict_label:
                        count += ori_words[i].count(word)
                # add-one (Laplace) smoothing so unseen words do not zero out p
                p *= (count + 1) / (word_sum + 20)
            p *= cp[predict_label]
            p_list.append(p)
        prob_matrix.append(p_list)
    tp.Textprocess().data_dump('probabilities', prob_matrix)
    print(prob_matrix)
    return prob_matrix
count_list = words_sum()
cp = category_probability(count_list)
prob_matrix = predict(count_list, cp)

# the class with the largest score is the prediction for each document
probability = tp.Textprocess().data_load('probabilities')
a = np.array(probability)
pred_labels = np.argmax(a, axis=1)
true_labels = tp.Textprocess().data_load('test_labels')
accuracy = metrics.accuracy_score(true_labels, pred_labels)
print("%.2f" % accuracy)
The final accuracy is around 0.7.
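One caveat: multiplying hundreds of probabilities that are all below 1 can underflow to 0.0 in floating point, which makes the argmax meaningless for long documents. A common refinement (not part of the experiment above) is to sum log-probabilities instead; a minimal sketch of the same scoring loop, assuming the data structures defined earlier:

import math

def predict_log(totals, cp):
    # same scoring as predict(), but in log space to avoid underflow
    prob_matrix = []
    for doc in range(len(test_words)):
        p_list = []
        for predict_label in range(20):
            log_p = math.log(cp[predict_label])
            word_sum = totals[predict_label]
            for word in test_words[doc]:
                count = sum(ori_words[i].count(word)
                            for i in range(len(ori_words))
                            if train_labels[i] == predict_label)
                log_p += math.log((count + 1) / (word_sum + 20))
            p_list.append(log_p)
        prob_matrix.append(p_list)
    return prob_matrix  # np.argmax picks the same winners on log scores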