First, I have to vent: do not ever try to train a model with a large training set on a laptop, or you really will cough up blood. The machine locks up, and when it finally responds again you are greeted with an "out of memory" error; it nearly finished me. Every run I had to crank virtual memory to the maximum and then watch the machine stutter along, wondering whether I was wearing it out for good. The ending was a happy one, though: the model's review accuracy came out above 94%, a genuinely delightful surprise. On the strength of that result I was given a workstation dedicated to running the algorithms. Happy~
OK, back to business. The THUCTC model's accuracy on article review was low, probably because its tokenization results and term weights didn't fit my data, so I had to go looking for a more suitable model and algorithm. I settled on gensim (a natural-language-processing library that implements Google's word2vec and also includes tf-idf) to build and train the word vectors, a CNN (convolutional neural network) as the classifier, and jieba for word segmentation, filtering the tokens by part of speech to build the word vectors. Those are the core pieces. There is also some document preprocessing, which you'll see in the code: mainly filtering out noisy, useless words and punctuation, and normalizing the articles to a uniform length.
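To give a sense of the gensim piece before we dive in, here is a minimal sketch of training word vectors on already-tokenized sentences. The toy corpus and parameters below are placeholders, not my production settings, and note that gensim versions before 4.0 use size where newer ones use vector_size:

from gensim.models import Word2Vec

# toy tokenized corpus; in practice each sentence is a jieba-segmented article
sentences = [["招标", "公告", "项目"],
             ["中标", "结果", "公示"]]
model = Word2Vec(sentences, size=128, window=5, min_count=1)  # 128-dim vectors
vector = model.wv["招标"]  # look up the trained vector for one word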
Enough talk; let's look at the code.
First comes preparing the training documents. The exact format doesn't matter, as long as you can read the data back later to feed the model; I went with TensorFlow's TFRecord format (.tfrecords). If you use plain text, just read and write the files yourself; here I'll describe storage and retrieval the way I did it, with TFRecords.
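For reference, a minimal sketch of what writing and reading .tfrecords can look like. This assumes TensorFlow 1.x, and the file name and feature keys are my own placeholders:

import tensorflow as tf

def write_records(samples, path="train.tfrecords"):
    # samples: an iterable of (text, label) pairs
    writer = tf.python_io.TFRecordWriter(path)
    for text, label in samples:
        example = tf.train.Example(features=tf.train.Features(feature={
            "text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode("utf-8")])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())  # one serialized Example per record
    writer.close()

def read_records(path="train.tfrecords"):
    # iterate over the raw records and parse each back into an Example
    for record in tf.python_io.tf_record_iterator(path):
        example = tf.train.Example()
        example.ParseFromString(record)
        text = example.features.feature["text"].bytes_list.value[0].decode("utf-8")
        label = example.features.feature["label"].int64_list.value[0]
        yield text, label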
The code is below; it also covers preparing and roughly preprocessing the documents to be predicted:
import configparser
import re
import time
import traceback

import jieba.posseg as pseg
import pymysql


class loadData:
    # =============================================================================
    # Read settings from the config file: start
    # =============================================================================
    def __init__(self, base_dir='.', path="database.conf"):
        self.base_dir = base_dir
        self.path = path
        self.cf = configparser.ConfigParser()
        self.cf.read(self.path)
    def get_configInfo(self, field, key):
        # return a single config value, or "" if the section/key is missing
        result = ""
        try:
            result = self.cf.get(field, key)
        except Exception:
            result = ""
        return result
    def read_config(self, config_file_path="database.conf", field="db"):
        cf = configparser.ConfigParser()
        # default to empty values so the return below never hits unbound names
        db_host = db_user = db_pass = db_name = db_charset = db_port = ""
        try:
            cf.read(config_file_path)
            db_host = cf.get(field, "db_host")
            db_user = cf.get(field, "db_user")
            db_pass = cf.get(field, "db_pass")
            db_name = cf.get(field, "db_name")
            db_charset = cf.get(field, "db_charset")
            db_port = cf.get(field, "db_port")
        except Exception as inst:
            traceback.print_exc()
            print(type(inst))
            print(inst.args)
            print(inst)
        return db_host, db_user, db_pass, db_name, db_charset, db_port
    # =============================================================================
    # Read settings from the config file: end
    # =============================================================================
    # =============================================================================
    # Database operations: start
    # =============================================================================
    def connDB(self, config_file_path="database.conf"):  # connect to the database
        db_host, db_user, db_pass, db_name, db_charset, db_port = self.read_config(config_file_path)
        conn = pymysql.connect(host=db_host, user=db_user, password=db_pass,
                               db=db_name, charset=db_charset, port=int(db_port))
        cur = conn.cursor(cursor=pymysql.cursors.DictCursor)  # rows come back as dicts
        return conn, cur

    def exeUpdate(self, conn, cur, sql):  # update, insert or delete
        sta = cur.execute(sql)
        conn.commit()
        return sta

    def exeQuery(self, cur, sql):  # select
        cur.execute(sql)
        return cur.fetchall()

    def connClose(self, conn, cur):  # close the connection and release resources
        cur.close()
        conn.close()
    # =============================================================================
    # Database operations: end
    # =============================================================================
    # =============================================================================
    # Read the training data from the database into a TFRecords file for the model to use: start
    # =============================================================================
    def createTestFile(self, kind, date=0):
        # kind selects the source table; date is how many days to look back
        tableName = {0: "bidding", 1: "article"}
        description = {0: "content", 1: "details"}
        map_type = {0: 2, 1: 1}  # label mapping (not used in this method)
        conn, cur = self.connDB()
        content_data = []
        content_id = []
        state = 1
        mark = 0
        pagesize = 100
        loop = True
        daytime = int(time.time()) - date * 3600 * 24
        while loop:
            # page through all rows of the target day, newest first
            sql = ("select * from " + tableName[kind] + " where state = " + str(state) +
                   " and FROM_UNIXTIME(time,'%Y-%m-%d') = FROM_UNIXTIME(" + str(daytime) +
                   ",'%Y-%m-%d') order by id desc limit " + str(mark * pagesize) + " , " + str(pagesize))
            print(sql)
            data_cur = self.exeQuery(cur, sql)
            if len(data_cur) > 0:
                for data in data_cur:
                    text = data["title"] + "\t" + data[description[kind]]
                    # denoise the text before it is fed to the model
                    content = self.seperate_line(self.clean_str(text))
                    content_data.append(content)
                    content_id.append(data["id"])
            else:
                loop = False
            mark = mark + 1
            print('Finished processing page %d' % mark)
        print("Data awaiting classification generated successfully!")
        self.connClose(conn, cur)
        return content_data, content_id
    def clean_str(self, string):
        # drop all whitespace, then replace ASCII letters/digits and punctuation
        # (both half- and full-width) with spaces; they only add noise
        string = re.sub(r'\s+', "", string)
        r1 = u'[A-Za-z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
        string = re.sub(r1, ' ', string)
        return string.strip()
    def seperate_line(self, line):
        # keep only the first half of the text, to even out article lengths
        length = len(line)
        line = line[0:int(length / 2)]
        line = pseg.cut(line)  # jieba segmentation with part-of-speech tags
        new_line = []
        for words, flag in line:
            if flag == 'nr' or flag == 'ns':
                continue  # drop person and place names
            if len(flag) == 0:
                continue
            if flag[0:1] != 'n' and flag != 'v':
                # if flag[0:1] != 'n':
                continue  # keep only nouns (flags starting with 'n') and verbs
            new_line.append(words)
        return ''.join([word + ' ' for word in new_line])  # space-separated tokens
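A hypothetical call, assuming a database.conf sits next to the script (the kind and date values here are just examples, not anything fixed):

loader = loadData(path="database.conf")
# kind=0 reads the "bidding" table; date=1 means one day back
content_data, content_id = loader.createTestFile(kind=0, date=1)
print("%d documents ready for prediction" % len(content_data))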