首先看一下工程目录:
mongo_client目录下存放的是mongodb数据库的连接,以及数据的获取
navie_bayes目录下存放的是朴素贝叶斯的实现
tags_posterior目录下存放的是已经计算好样本的标签的后验概率
tags_priori 目录下存放的是计算好的样本标签先验概率
training目录下存放的是样本训练的方法
接下来简单的介绍一下各目录下的python实现:
mongo_client/mongodb_client.py
class MongodbClient(object):
    """Thin wrapper around the local MongoDB ``toutiao.sample`` collection.

    Supplies tag discovery, the per-tag posterior files under
    ``tags_posterior/`` and the priori-probability file under
    ``tags_priori/`` that the naive Bayes classifier reads.
    """
    def __init__(self):
        # NOTE(review): connection parameters are hard-coded; assumes a
        # local mongod on the default port with a ``toutiao`` database.
        db_client = MongoClient("localhost", 27017)
        db = db_client.toutiao
        self.collection = db.sample

    def find_all_db_tags(self):
        """Return the distinct 'tag' values stored in MongoDB.

        :return: list of tag names (possibly empty; a warning is printed then)
        """
        tags_list = self.collection.distinct('tag')
        if not tags_list:
            print('no tags_posterior in db, please checkout the connect of mongodb')
        return tags_list

    def find_all_dir_tags(self):
        """Return file names under tags_posterior/ ('tag.txt', not 'tag')."""
        return os.listdir('tags_posterior/')

    def create_tag_file(self):
        """Create one empty posterior file per database tag in tags_posterior/."""
        if not os.path.isdir('tags_posterior'):
            os.makedirs('tags_posterior')
        for tag in self.find_all_db_tags():
            # tag is unicode, but we declare all files as utf-8
            filename = 'tags_posterior/' + tag.encode('utf-8') + '.txt'
            if not os.path.exists(filename):
                # open + close just to create the file; the original
                # version leaked the handle
                open(filename, 'wb').close()

    def find_all_articles(self):
        """Return a cursor over every document with usable content.

        Documents whose 'content' is missing or equals the placeholder
        '无' are excluded by the query.
        """
        article_list = self.collection.find({'content': {'$exists': True, '$ne': '无'}})
        # BUG FIX: the original compared the bound method ``count`` to 0,
        # which is always False; the method must be called.
        if article_list.count() == 0:
            print('no data in database, please checkout the database')
        else:
            return article_list

    def create_tags_probability(self):
        """Append 'tag probability count' lines to tags_priori/priori.txt."""
        if not os.path.isdir('tags_priori'):
            os.makedirs('tags_priori')
        # hoisted out of the loop: one total-count query instead of one
        # per tag
        total = float(self.find_all_articles().count())
        with open('tags_priori/priori.txt', 'ab') as probability_file:
            for tag in self.find_all_dir_tags():
                # directory entries look like 'tag.txt'; strip '.txt'
                tag_name = tag[:-4]
                tag_count = self.collection.find({'tag': tag_name,
                                                  'content': {'$exists': True, '$ne': '无'},
                                                  }).count()
                # the priori probability of the tag over the whole corpus
                probability = float(tag_count) / total
                line = tag_name + ' ' + repr(probability) + ' ' + repr(tag_count) + '\n'
                probability_file.write(line)

    def find_tags_p_probability(self, tag):
        """Return the priori probability recorded for ``tag``.

        :param tag: tag name (first whitespace-separated field of a line)
        :return: the stored probability as float, or 0 when absent
        """
        # text mode ('r', not 'rb'): the file holds whitespace-separated text
        with open('tags_priori/priori.txt', 'r') as priori_file:
            for line in priori_file:
                fields = line.split(' ')
                if fields[0] == tag:
                    return float(fields[1])
        return 0
navie_bayes/naive_bayes_classifier.py
import jieba
import jieba.analyse
from mongo_client.mongodb_client import MongodbClient
class NBClassifier(object):
    """Naive Bayes article classifier backed by per-tag posterior files."""
    def __init__(self):
        self.db_client = MongodbClient()

    def classify_article(self, article):
        """Classify ``article`` and print the MAP tag.

        Extracts the 10 strongest TF-IDF keywords, then for every tag
        file takes the best per-sample posterior product and multiplies
        it by the tag's priori probability; the tag with the maximum
        product wins (MAP, 最大后验概率).
        """
        extract_keywords = jieba.analyse.extract_tags(article, topK=10)
        tags_list = self.db_client.find_all_dir_tags()
        # one probability per tag, same order as tags_list
        naive_bayes_probability = list()
        for tag in tags_list:
            sample_posteriors = list()
            with open('tags_posterior/' + tag, 'r') as tags_file:
                for line in tags_file:
                    # BUG FIX: the original performed a no-op
                    # decode('utf-8').encode('utf-8') round trip here
                    words_list = line.split(' ')
                    w_list = list()  # keywords of one sample line
                    p_list = list()  # weights matching w_list
                    for index, word in enumerate(words_list):
                        if index == 0:
                            # the first token is the group id, not a keyword
                            continue
                        elif index % 2 == 1:
                            w_list.append(word)
                        else:
                            p_list.append(word.strip('\n'))
                    sample_posteriors.append(
                        self.calculate_tag_posterior(extract_keywords, w_list, p_list))
            # guard against an empty tag file: max([]) raises ValueError
            best = max(sample_posteriors) if sample_posteriors else 0.0
            naive_bayes_probability.append(best * self.calculate_tag_priori(tag[:-4]))
        max_probability = max(naive_bayes_probability)
        max_index = naive_bayes_probability.index(max_probability)
        print(naive_bayes_probability)
        print('MAP: %s 该文章属于: %s' % (max_probability, tags_list[max_index][:-4]))
        if len(tags_list) > 10:
            # debug leftover: the test article's true tag sat at index 10;
            # guarded so fewer than 11 tags no longer raises IndexError
            print('本来属于: %s %s' % (tags_list[10], naive_bayes_probability[10]))

    def calculate_tag_priori(self, tag):
        """Look up the priori probability of ``tag`` from priori.txt."""
        return self.db_client.find_tags_p_probability(tag)

    def calculate_tag_posterior(self, keywords, w_list, p_list):
        """Multinomial posterior product for one sample line.

        :param keywords: keywords of the article under test
        :param w_list: keywords of the sample article
        :param p_list: weights matching ``w_list``
        :return: product of the smoothed per-keyword posteriors (后验概率乘积)
        NOTE: ``keywords`` and ``w_list`` may have different lengths.
        """
        p = 1.0
        c = 100  # assume every article has roughly 100 keywords
        for word in keywords:
            # normalise to a byte string; plain str needs no encoding.
            # (the original tested ``unicode``, a name that does not
            # exist on Python 3; this form behaves identically on 2)
            if not isinstance(word, str):
                word = word.encode('utf-8')
            if word in w_list:
                # polynomial-model smoothing (平滑处理)
                p *= float((float(p_list[w_list.index(word)]) * c + 1)) / float(c + 2)
            else:
                p *= float((0.0 * c + 1)) / float(c + 2)
        return p
training/navie_bayes_training.py
import os
import jieba
import jieba.analyse
from mongo_client.mongodb_client import MongodbClient
class NBTraining(object):
    """Builds the per-tag keyword (posterior) files from MongoDB articles."""
    def __init__(self):
        self.db_client = MongodbClient()

    def create_tags_list(self):
        """Create one posterior file per tag found in the database."""
        self.db_client.create_tag_file()

    def find_all_tags_list(self):
        """Return the 'tag.txt' file names under tags_posterior/."""
        return self.db_client.find_all_dir_tags()

    def find_all_articles_list(self):
        """Return a cursor over all articles that have usable content."""
        return self.db_client.find_all_articles()

    def add_stop_word(self):
        """Install the user-defined Chinese stop-word list for jieba.

        Removes low-information words before keyword extraction.
        """
        jieba.analyse.set_stop_words('stop_words.txt')

    def open_parallel_analyse(self, thread_count):
        """Enable jieba parallel processing.

        :param thread_count: number of parallel workers
        """
        jieba.enable_parallel(thread_count)

    def tf_idf_analyze_article(self):
        """Extract TF-IDF keywords for every article into its tag file."""
        article_list = self.find_all_articles_list()
        tags_list = self.find_all_tags_list()
        # hoisted out of the loop: stop words and parallel mode are
        # global jieba configuration and only need to be set once
        self.add_stop_word()
        self.open_parallel_analyse(4)
        for article in article_list:
            # only documents carrying both a content and a tag are usable
            if 'content' not in article or 'tag' not in article:
                continue
            # BUG FIX: the original used ``!= u'无' or != ''`` which is
            # always True; both placeholder values must be rejected
            if article['content'] == u'无' or article['content'] == '':
                continue
            article_name = (article['tag'] + '.txt').encode('utf-8')
            tag_path = 'tags_posterior/' + article_name
            # skip articles whose group id was already written to the file
            group_id_list = self.find_all_tag_group_id(tag_path)
            if self.exist_group_id(group_id_list, repr(article['group_id'])):
                continue
            if article_name in tags_list:
                # analyse title + content together
                content = (article['title'] + article['content']).encode('utf-8')
                # extract the 10 strongest keywords together with weights
                extract_keywords = jieba.analyse.extract_tags(content,
                                                              topK=10,
                                                              withWeight=True)
                # the group id is a long; store its repr as the line key
                article_keywords = [repr(article['group_id'])]
                for keyword in extract_keywords:
                    article_keywords.append(keyword[0].encode('utf-8'))  # word
                    article_keywords.append(repr(keyword[1]))            # weight
                article_keywords_line = ' '.join(str(word) for word in article_keywords)
                with open(tag_path, 'ab') as tags_file:
                    tags_file.write(article_keywords_line + '\n')
            else:
                print('no tags_posterior for this ' + article_name)

    def find_all_tag_group_id(self, path):
        """Return the group ids already stored in the tag file at ``path``.

        :param path: posterior file path ('tags_posterior/<tag>.txt')
        :return: list of group-id strings; [] when the file is missing
                 (the original returned None, which broke the ``in``
                 membership test performed by the caller)
        """
        if not os.path.exists(path):
            print('the path for group id is not exist')
            return []
        group_id_list = list()
        with open(path, 'r') as tags_file:
            for line in tags_file:
                words_list = line.split(' ')
                if words_list:
                    # the first token of every line is the group id
                    group_id_list.append(words_list[0])
        return group_id_list

    def exist_group_id(self, group_id_list, group_id):
        """True when ``group_id`` is already recorded (avoids duplicate lines).

        :param group_id_list: ids previously written to the tag file
        :param group_id: candidate id
        :return: True or False
        """
        return group_id in group_id_list

    def clear_all_tags_file(self):
        """Manually truncate every posterior file (drops all keywords)."""
        for tag in self.find_all_tags_list():
            # opening with 'wb' truncates; the context manager closes the
            # handle the original leaked
            with open('tags_posterior/' + tag, 'wb'):
                pass
training/stop_words.txt
停用词:意义不大的词,不需要统计的词语