Thucnews数据集
由于本地机器资源有限,所以仅拿出4W数据来做训练和测试
我把thucnews数据存到了本地mongo数据集上
如下所示(请忽略label 和lable 的字母拼写错误…)
项目目录结构
其中classify.py->分类主体文件
classify_text.txt->测试样例,自己搜集的数据
gpos-vocab->自己训练的词典库
stopwords->停用词典
text_util.py->对文本进行操作的工具
thucnews.json->读取mongo库数据,保存到本地的文件
代码如下
classify.py
#!/usr/bin/python
# -*- coding:utf-8 -*-
"""
@Description
@Author LiHao
@Date 2020/6/22
"""
import numpy as np
import tensorflow as tf
import pymongo
from text_classification.text_util import WordUtil
import os, json
# Shared text-processing helper: loads the stopword list and the trained
# vocabulary ('gpos-vocab'); use_pad=True presumably pads/truncates id
# sequences to a fixed length — TODO confirm against WordUtil.
wu = WordUtil(stopwords_path='stopwords', vocabulary_idf_path='gpos-vocab', use_pad=True)
# Mapping from integer label id (as stored in Mongo field 'lableId',
# misspelled at the source) to the human-readable THUCNews category name.
# 14 categories total, ids 0-13.
thucnews_label_dict = {
0: '体育', 1: '娱乐', 2: '家居', 3: '彩票', 4: '房产',
5: '教育', 6: '时尚', 7: '时政', 8: '星座', 9: '游戏',
10: '社会', 11: '科技', 12: '股票', 13: '财经'}
def load_thucnews():
    """
    Load the THUCNews dataset.

    If a local cache file 'thucnews.json' exists it is loaded directly;
    otherwise the documents are pulled from the local MongoDB collection
    'corpus.thucnews' and the result is cached to 'thucnews.json'.

    :return: dict with keys 'train'/'test'/'dev', each mapping to
             {'x': list of fixed-length word-id sequences,
              'y': list of one-hot label vectors}
    """
    # Derive the one-hot width from the label mapping instead of
    # hard-coding 14, so the two cannot drift apart.
    num_labels = len(thucnews_label_dict)

    def get_doc(cursor):
        """Convert a Mongo cursor of documents into (x, y) lists."""
        x = []
        y = []
        for doc in cursor:
            # Title words first, then content words (CR/LF stripped).
            title_words = wu.cut_use_stopwords_vocab(doc['title'])
            content_words = wu.cut_use_stopwords_vocab(
                doc['content'].replace('\r', '').replace('\n', ''))
            title_words.extend(content_words)
            # Map words to ids, padded/truncated to length 150.
            words_id = wu.turn_words_2_id(title_words, 150)
            # NOTE: 'lableId' is intentionally misspelled — that is the
            # actual field name in the Mongo collection.
            label = int(doc['lableId'])
            y_ = [0] * num_labels
            y_[label] = 1
            x.append(words_id)
            y.append(y_)
        return x, y

    if not os.path.exists('thucnews.json'):
        client = pymongo.MongoClient('localhost', 27017)
        try:
            db = client.get_database('corpus')
            coll = db.get_collection('thucnews')
            # Convention in this collection: type 0 = train,
            # 1 = dev, 2 = test.
            train_x, train_y = get_doc(coll.find({'type': 0}))
            dev_x, dev_y = get_doc(coll.find({'type': 1}))
            test_x, test_y = get_doc(coll.find({'type': 2}))
        finally:
            # Fix: the original skipped close() if get_doc raised.
            client.close()
        data = {
            'train': {'x': train_x, 'y': train_y},
            'test': {'x': test_x, 'y': test_y},
            'dev': {'x': dev_x, 'y': dev_y},
        }
        # Fix: the original never closed the file handle it passed to
        # json.dump; 'with' guarantees the cache file is flushed/closed.
        with open('thucnews.json', 'w', encoding='utf-8') as f:
            json.dump(data, f)
    else:
        with open('thucnews.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
    print('load thucnews success')
    return data
class TextCNNModel():
def __init__(self, model_path=None, num_label=10, embedding_size=50, max_len=150<