Bayesian Classification Methods, Part 3: Automatic News Text Classification with Naive Bayes using Python + jieba + MongoDB

First, a look at the project directory layout:

(screenshot of the project directory structure)

 The mongo_client directory holds the MongoDB connection and the data-access helpers.
 The navie_bayes directory holds the Naive Bayes classifier implementation.
 The tags_posterior directory holds the precomputed posterior data for each sample tag.
 The tags_priori directory holds the precomputed prior probabilities of the sample tags.
 The training directory holds the sample-training code.
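Putting the pieces together, a minimal sketch of the intended workflow is shown below: build the per-tag keyword files, compute the tag priors, then classify a new article. The run.py name, the classifier module path and the sample text are assumptions made for illustration; only the classes and methods themselves come from the code in this post.

# -*- coding: utf-8 -*-
# run.py -- hypothetical driver script, assuming the package layout above
from mongo_client.mongodb_client import MongodbClient
from training.navie_bayes_training import NBTraining
# the classifier module name is not shown in this post; nb_classifier is assumed
from navie_bayes.nb_classifier import NBClassifier

# 1. training: create tags_posterior/<tag>.txt files and fill them with keywords
trainer = NBTraining()
trainer.create_tags_list()
trainer.tf_idf_analyze_article()

# 2. priors: write each tag's frequency into tags_priori/priori.txt
MongodbClient().create_tags_probability()

# 3. classification: MAP decision over all tags for a new, unseen article
classifier = NBClassifier()
classifier.classify_article(u'这是一段待分类的新闻正文,仅作演示用')  # made-up test text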
 Below is a brief walkthrough of the Python implementation in each directory:

mongo_client/mongodb_client.py

# -*- coding: utf-8 -*-
import os

from pymongo import MongoClient


class MongodbClient(object):
    """
    connect to MongoDB and access the toutiao.sample collection
    """
    def __init__(self):
        db_client = MongoClient("localhost", 27017)
        db = db_client.toutiao
        self.collection = db.sample

    def find_all_db_tags(self):
        """
        collect all distinct tags stored in MongoDB
        :return:
        """
        tags_list = self.collection.distinct('tag')

        if len(tags_list) == 0:
            print 'no tags in db, please check the MongoDB connection'

        return tags_list

    def find_all_dir_tags(self):
        """
        find the tags that already exist in the tags_posterior dir;
        entries are file names like 'tag.txt', not bare 'tag'
        :return:
        """
        tags_list = os.listdir('tags_posterior/')

        return tags_list

    def create_tag_file(self):
        """
        create an empty posterior file for every tag under tags_posterior/
        :return:
        """
        if not os.path.isdir('tags_posterior'):
            os.makedirs('tags_posterior')

        tags_list = self.find_all_db_tags()

        if len(tags_list) > 0:
            for tag in tags_list:
                # tag is unicode; file names are written as utf-8
                filename = 'tags_posterior/' + tag.encode('utf-8') + '.txt'

                if not os.path.exists(filename):
                    # create the empty file and close it right away
                    open(filename, 'wb').close()

    def find_all_articles(self):
        """
        return a cursor over all articles that have a usable content field
        :return:
        """
        article_list = self.collection.find({'content': {'$exists': True, '$ne': '无'}})
        if article_list.count() == 0:
            print 'no data in database, please check the database'
        else:
            return article_list

    def create_tags_probability(self):
        """
        compute the prior probability of every tag and write it to tags_priori/priori.txt
        """
        if not os.path.isdir('tags_priori'):
            os.makedirs('tags_priori')

        total_count = self.find_all_articles().count()

        probability_file = open('tags_priori/priori.txt', 'ab')

        tags_list = self.find_all_dir_tags()
        for tag in tags_list:
            # entries look like 'tag.txt', strip the extension
            tag_name = tag[:-4]
            tag_count = self.collection.find({'tag': tag_name,
                                              'content': {'$exists': True, '$ne': '无'},
                                              }).count()

            # prior probability of this tag in the database
            probability = float(tag_count) / float(total_count)

            line = tag_name + ' ' + repr(probability) + ' ' + repr(tag_count) + '\n'
            probability_file.write(line)

        probability_file.close()

    def find_tags_p_probability(self, tag):
        """
        look up the prior probability of a tag in tags_priori/priori.txt
        """
        priori_file = open('tags_priori/priori.txt', 'rb')

        for line in priori_file:
            fields = line.split(' ')
            if fields[0] == tag:
                priori_file.close()
                return float(fields[1])

        priori_file.close()
        return 0
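A note on the pymongo calls used above: the cursor count() method (find(...).count()) was deprecated in PyMongo 3.7 and removed in 4.0. If you run this code against a newer driver, the equivalent counting can be done with count_documents() on the collection; a minimal sketch with the same filter as find_all_articles:

# -*- coding: utf-8 -*-
# counting articles with a newer pymongo (>= 3.7) instead of cursor.count()
from pymongo import MongoClient

db = MongoClient('localhost', 27017).toutiao
article_count = db.sample.count_documents({'content': {'$exists': True, '$ne': u'无'}})
print article_count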
navie_bayes/

# -*- coding: utf-8 -*-
import jieba
import jieba.analyse

from mongo_client.mongodb_client import MongodbClient


class NBClassifier(object):
    """
    naive Bayes classifier for news articles
    """
    def __init__(self):
        self.db_client = MongodbClient()

    def classify_article(self, article):
        """
        classify an article by choosing the tag with the maximum posterior probability
        """
        # extract the top 10 keywords of the article to be classified
        extract_keywords = jieba.analyse.extract_tags(article,
                                                      topK=10)

        tags_list = self.db_client.find_all_dir_tags()

        # naive Bayes score (posterior * prior) for each tag
        naive_bayes_probability = list()

        # walk through every tag file under tags_posterior/
        for tag in tags_list:

            tags_file = open('tags_posterior/' + tag, 'r')

            p = list()

            for line in tags_file:
                # split the line into the tag's keywords and weights
                words_list = line.split(' ')
                # keywords stored in the tag file
                w_list = list()
                # weights of those keywords
                p_list = list()

                for index, word in enumerate(words_list):
                    # the first field is the group id, which is not a keyword
                    if index == 0:
                        continue
                    elif index % 2 == 1:
                        w_list.append(word)
                    else:
                        p_list.append(word.strip('\n'))

                p.append(self.calculate_tag_posterior(extract_keywords, w_list, p_list))

            tags_file.close()

            # score = best posterior among this tag's sample articles, times the tag prior
            if len(p) > 0:
                naive_bayes_probability.append(max(p) * self.calculate_tag_priori(tag[:-4]))
            else:
                # empty tag file: no training sample, score 0
                naive_bayes_probability.append(0.0)

        # MAP (maximum a posteriori) decision
        max_probability = max(naive_bayes_probability)
        max_index = naive_bayes_probability.index(max_probability)
        print naive_bayes_probability
        print 'MAP:', max_probability, 'this article is classified as:', tags_list[max_index][:-4]
        # debug output: index 10 is the hard-coded position of the test article's real tag
        print 'actual tag:', tags_list[10], naive_bayes_probability[10]

    def calculate_tag_priori(self, tag):
        return self.db_client.find_tags_p_probability(tag)

    def calculate_tag_posterior(self, keywords, w_list, p_list):
        """
        multinomial model: multiply the (smoothed) probability of each keyword
        :param keywords: the keywords of the article to be classified
        :param w_list: the keywords of one sample article for this tag
        :param p_list: the weights of those sample keywords
        :return: the product of the keyword posteriors

        note: keywords and w_list are not necessarily the same length
        """

        p = 1.0
        c = 100   # assume every article has roughly 100 keywords (smoothing constant)

        for index, word in enumerate(keywords):

            # jieba returns unicode; the stored keywords are utf-8 byte strings
            if isinstance(word, unicode):
                word = word.encode('utf-8')

            if word in w_list:
                # Laplace-style smoothing of the stored weight
                p *= float((float(p_list[w_list.index(word)]) * c + 1)) / float(c + 2)

            else:
                # unseen keyword: smoothed instead of multiplying by zero
                p *= float((0.0 * c + 1)) / float(c + 2)
        return p

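The smoothing inside calculate_tag_posterior is a Laplace-style (add-one) correction over an assumed keyword count of c = 100: a keyword found in the tag file contributes (weight * c + 1) / (c + 2), and a keyword missing from the tag file contributes 1 / (c + 2) instead of zero, so one unseen word cannot wipe out the whole product. A small standalone sketch of the same arithmetic:

def smoothed_factor(weight, c=100):
    # Laplace-style smoothing as used in calculate_tag_posterior;
    # weight is the stored TF-IDF weight, or 0.0 for an unseen keyword
    return (weight * c + 1.0) / (c + 2.0)

print smoothed_factor(0.25)   # keyword present with weight 0.25 -> 26 / 102, about 0.255
print smoothed_factor(0.0)    # unseen keyword -> 1 / 102, about 0.0098, never exactly 0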
training/navie_bayes_training.py

# -*- coding: utf-8 -*-
import os

import jieba
import jieba.analyse

from mongo_client.mongodb_client import MongodbClient


class NBTraining(object):
    def __init__(self):
        self.db_client = MongodbClient()

    def create_tags_list(self):
        self.db_client.create_tag_file()

    def find_all_tags_list(self):
        return self.db_client.find_all_dir_tags()

    def find_all_articles_list(self):
        return self.db_client.find_all_articles()

    def add_stop_word(self):
        """
        load the user-defined stop-word list so meaningless words are ignored
        during keyword extraction
        :return:
        """
        jieba.analyse.set_stop_words('stop_words.txt')

    def open_parallel_analyse(self, thread_count):
        """
        enable jieba's parallel mode for faster keyword extraction
        :param thread_count: number of parallel workers
        :return:
        """
        jieba.enable_parallel(thread_count)

    def tf_idf_analyze_article(self):
        """
        use the TF-IDF model to extract keywords from every article
        and append them to the corresponding tag file
        :return:
        """
        article_list = self.find_all_articles_list()
        tags_list = self.find_all_tags_list()

        for article in article_list:

            # only articles that have both a content and a tag field are analysed
            if 'content' in article and 'tag' in article:
                if article['content'] != u'无' and article['content'] != '':

                    article_name = (article['tag'] + '.txt').encode('utf-8')
                    tag_path = 'tags_posterior/' + article_name

                    # skip articles whose group id is already in the tag file
                    group_id_list = self.find_all_tag_group_id(tag_path)
                    if self.exist_group_id(group_id_list, repr(article['group_id'])):
                        continue

                    # the article has not been analysed yet: extract its keywords
                    # first, load the user-defined chinese stop words
                    self.add_stop_word()

                    # second, extract the top 10 keywords together with their weights
                    if article_name in tags_list:

                        # analyse title + content together
                        content = (article['title'] + article['content']).encode('utf-8')

                        # run extraction with 4 parallel workers
                        self.open_parallel_analyse(4)

                        extract_keywords = jieba.analyse.extract_tags(content,
                                                                      topK=10,
                                                                      withWeight=True)
                        article_keywords = list()

                        # group id is a long, convert it to a string first
                        article_keywords.append(repr(article['group_id']))

                        for keyword in extract_keywords:
                            # keyword text
                            article_keywords.append(keyword[0].encode('utf-8'))
                            # keyword weight
                            article_keywords.append(repr(keyword[1]))

                        article_keywords_line = ' '.join(str(word) for word in article_keywords)

                        tags_file = open(tag_path, 'ab')
                        tags_file.write(article_keywords_line + '\n')
                        tags_file.close()

                    else:
                        print 'no tags_posterior file for ' + article_name

    def find_all_tag_group_id(self, path):
        """
        collect all group ids already written to a tag file
        :param path:
        :return: list of group ids
        """
        if not os.path.exists(path):
            print 'the path for group ids does not exist'
            return

        tags_file = open(path, 'rb')

        group_id_list = list()

        for line in tags_file:
            words_list = line.split(' ')

            if len(words_list) > 0:
                # the first field of every line is the group id
                group_id_list.append(words_list[0])

        tags_file.close()

        return group_id_list

    def exist_group_id(self, group_id_list, group_id):
        """
        check whether a group id is already recorded, so keyword lines are not
        written repeatedly
        :param group_id_list:
        :param group_id:
        :return: True or False
        """
        return group_id in group_id_list

    def clear_all_tags_file(self):
        """
        empty every file under tags_posterior/ (manual reset of the training data)
        :return:
        """
        tags_list = self.find_all_tags_list()
        for tag in tags_list:
            # opening in 'wb' mode truncates the file
            tag_file = open('tags_posterior/' + tag, 'wb')
            tag_file.close()
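To make the training output concrete: jieba.analyse.extract_tags with withWeight=True returns a list of (keyword, weight) pairs, which tf_idf_analyze_article flattens into one line per article. A quick standalone illustration (the sentence is made up, and the exact weights depend on jieba's IDF dictionary):

# -*- coding: utf-8 -*-
import jieba.analyse

text = u'小米今日发布了新款手机,搭载高通骁龙处理器'  # made-up sample sentence
for word, weight in jieba.analyse.extract_tags(text, topK=5, withWeight=True):
    # each keyword is printed together with its TF-IDF weight
    print word, weight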

training/stop_words.txt

Stop words: words that carry little meaning and do not need to be counted.
(screenshot of the stop_words.txt content)
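jieba.analyse.set_stop_words expects a plain-text file with one stop word per line; any word listed there is dropped before the TF-IDF ranking. The lines below only illustrate the format, they are not the project's actual stop-word list:

的
了
是
我们
这个
因为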

tags_posterior/

(screenshot of the per-tag files under tags_posterior/)
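Each tags_posterior/<tag>.txt file contains one line per training article, as written by tf_idf_analyze_article: the article's group_id first, then up to ten keyword / TF-IDF-weight pairs, all separated by single spaces. A made-up example line:

6412357890123456789 小米 0.82310 手机 0.51274 发布 0.33025 ...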

tags_priori/priori.txt

(screenshot of tags_priori/priori.txt)
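Each line in tags_priori/priori.txt, as written by create_tags_probability, holds a tag name, the tag's prior probability (its share of all usable articles), and the raw article count for that tag, separated by spaces. A made-up example:

科技 0.12853 412
体育 0.09731 312
财经 0.08214 263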
