TextRNN实现文本分类
任务介绍
给定一个如下的外卖评论的数据(1w条),训练模型分类好评和差评。
思路
给出的baseline为0.82(F1),方法是将语料中所有字拆开训练成300D的word2vec后,每一句的处理采用将所有字的向量相加取平均的方法得到句向量(300D),然后使用一个全连接层进行训练。
优化的思路自然就是从这个方法的缺点入手,主要的提升点有:
- 把所有的字拆开进行训练可以改进为先清洗语料(清除一些特殊符号、表情包等),并用分词后的词(word)进行word2vec
- 原来句向量的方法对句子中的信息可能会丢失,可以改用把所有词向量都输入网络中进行训练
- 改用更加复杂的模型如TextRNN、TextCNN等(最后使用了双层的LSTM)。另外的方法可以参考(https://www.cnblogs.com/sandwichnlp/p/11698996.html)
TextRNN
TextRNN的结构如图所示
可以看到主要就是把词向量输入一个BiLSTM(双向LSTM),将两个方向最后得到的向量concat,然后经过一个全连接层(FC),最后softmax。
LSTM结构解决了RNN中梯度消失(Vanishing gradients)和梯度爆炸(exploding gradients)的问题,原始的RNN对短期输入比较敏感,而LSTM增加了一个单元状态(cell state)来记忆长期的信息,通过输入门(input gate)、输出门(output gate)和遗忘门(forget gate)来决定LSTM最终的输出。另外还有一种变体叫GRU,其仅有更新门(update gate)和重置门(reset gate)且将单元状态和输出合并。(详见https://www.jianshu.com/p/19fd7206f070 )
但最终使用的网络却并非这一个结构,在测试中,这一网络最好的表现只有0.87(F1),最后取得0.90好成绩的网络是一个堆叠的uniLSTM网络,其结构如下
可以看到其中添加了Dropout来防止过拟合(详见 https://blog.csdn.net/program_developer/article/details/80737724),个人觉得因为训练数据偏少(1w条,划分训练集测试集后更少),因此优势较为明显。其主要为两个单向的LSTM加上Dropout和Softmax。
新手上路
看了这么多理论,就可以开始动手实践了。实际上动手实践才是比较难的内容(因为模型实现主要使用keras,对模型内部细节不必理解得十分透彻)。
除此之外环境也是一个十分让人头疼的事情,tensorflow1.x与cuda、cudnn等有十分令人头大的依赖关系,其配置过程也十分麻烦,这里一劳永逸使用了据说比较稳定的tf2.1。
环境
- Win10 1903 @ R7-3700x 16G RTX-2060
- Anaconda
- Python 3.7.6
- tensorflow-gpu==2.1
- CUDA 10.1
- CuDNN 7.6
ps. 建议保持CUDA和CuDNN的版本一直不变,直到tf指定更新的版本。另外建议你使用与我一样的环境配置。
过程一共有如下几步:数据预处理、模型构建、训练预测。
数据预处理
首先要清洗掉一些特殊符号、表情等,这一步比较简单,用正则表达式很方便就能完成。贴下代码
import re
pat = re.compile("[^u4e00-u9fa5\w,.,。!!?! ]")
res = []
with open("../data/input/train_input.txt", "r", encoding='utf-8') as f:
for line in f:
line = re.sub(pat, "", line)
res.append(line)
with open("../data/input/train_input_.txt", 'w', encoding='utf-8') as fw:
for line in res:
fw.write(line + "\n")
然后是词向量(word2vec),这里使用gensim的工具,处理完后保存。贴下代码
import pandas as pd
import jieba
import logging
from gensim.models.word2vec import Word2Vec
file = "data/train.csv"
data = pd.read_csv(file, sep='\t', encoding='UTF-8', header=None)
sentence = list(data[1])
def segment_sen(sen):
sen_list = jieba.lcut(sen)
return sen_list
sens_list = [segment_sen