看这篇论文前,建议先了解一下policy gradient RL,就更很容易理解论文思想了。
论文:《Learning Structured Representation for Text Classification via Reinforcement Learning》
代码:http://coai.cs.tsinghua.edu.cn/publications/
一、论文原理
这篇论文在文本分类任务中,应用了policy gradient强化学习的方法,来得到更好的句子结构化表征(ID-LSTM model保留有用单词,删除无用的单词如"a","the"等;HS-LSTM model将整个序列划分为多个短语结构),从而得到更好的文本分类效果。
二、模型结构
模型分为三个部分:
策略网络(PNet)、结构化表示结构(两个LSTM Module)、分类网络(CNet).
这里的两个LSTM Module是分别训练的,PNet决定Information Distilled LSTM (ID-LSTM)中是否保留当前单词,action为{Retain, Delete};PNet决定Hierarchically Structured LSTM (HS-LSTM) 中word-level lstm当前单词是否是短语结束位置/短语中,action为{Inside, End},再将判断的短语输入phrase-level lstm得到序列结构化特征。 下面会详细介绍。
-
策略网络(PNet)根据 结构化表示模型(LSTM Model) 中每一个step的输入和上一层隐层状态决定当前采取的action (即是否保留/删除该单词、该单词是否在短语中/结束处)。
-
在完成一序列action后,结构化表示模型(LSTM Model) 输出最终的文本特征。
-
分类网络(CNet)对输入的文本特征分类,根据分类结果对