本文内容是上文《LSTM基本原理及实践(上)》的实战篇:手动实现LSTM网络并对文本数据进行分类。
本文的内容和资料是来自2020 Google DevFest TensorFlow CodeLab活动中段清华老师的分享~
任务描述
- 基本任务:不使用TF官方的LSTM层,实现一个自定义的、继承自
tf.keras.Layer
的LSTM层 - 进阶任务:基于以上任务,实现一个直接以字符串为输入的文本分类模型
- 终极任务:基于以上模型,训练互联网情感分析数据集INEWS,并给出训练后的准确率。
数据集
-
各系统用户可以通过终端下载数据集,我以天池notebook环境为例:
pip install zh-dataset-inews
-
注意 天池notebook中下载完数据集需要重启kernel
-
查看数据集
from zh_dataset_inews import title_train,label_train,content_train
len(title_train)
输出:
for x,y in zip(title_train[:10],label_train[:10]):
print(x,y)
输出:
LSTM手动实现
前置知识:
- LSTM原理具体可以看《LSTM基本原理及实践(上)》
- LSTM相关公式如下:
- 根据官方手册,自定义实现一个模型主要是三步:
__init__
:进行于输入无关的初始化build
:当你知道输入的张量的维度时也可以将他们初始化call
:进行前向计算
1 基础任务
class CustomLSTM(tf.keras.layers.Layer):
"""
#LSTM's input:{batch_size,sequence_length,input_size}
#LSTM'S output 1:{batch_size,sequence_length,output_size}
# output 2:{batch_size,output_size}
"""
def __init__(self,output_size,return_sequences=False):
super(CustomLSTM,self).__init__()
self.output_size=output_size
self.return_sequences=return_sequences