RCNN模型也是用于文本分类的常用模型,其源论文为Recurrent Convolutional Neural Networks for Text Classification。
模型整体结构如下:
架构主要包括如下模块:
(1)通过双向RNN模型,得到每个token上下文的信息(隐层输出);
(2)通过隐层输出与原始embedding的拼接,得到扩展后的token信息;
(3)后面接类似于TextCNN的CNN、max-pooling和fc层,得到分类结果。
整个模型结构还是非常清晰的,下面给出pytorch的简单实现:
import torch
import torch.nn as nn
Config = {"vob_size": 100, # vocabulary size
"ebd_size": 50, # word-embedding dimension
"hidden_size": 20, # LSTM hidden size per direction (original comment said "vocab size" — copy-paste error)
"num_layer": 2, # number of stacked LSTM layers
"bidirectiion": True, # bidirectional flag — NOTE(review): key is misspelled and never read; the model hardcodes bidirectional=True
"drop":0.3, # dropout ratio between stacked LSTM layers
"cnn_channel":100, # output_channel of the 1D convolutions
"cnn_kernel": 3, # kernel size of the 1D convolutions
"topk": 10, # keep top-k values per channel from the CNN output (k-max pooling)
"fc_hidden": 10, # hidden width of the classifier MLP
"fc_cla": 4, # number of output classes
}
class LSTM_pool(nn.Module):
    """RCNN-style text classifier (Lai et al. 2015, "Recurrent Convolutional
    Neural Networks for Text Classification").

    Pipeline:
        token ids -> embedding -> BiLSTM -> concat(embedding, BiLSTM output)
        -> two Conv1d/BatchNorm/ReLU blocks -> k-max pooling over time
        -> 2-layer MLP -> class logits.

    All hyper-parameters are read from the module-level ``Config`` dict.
    """

    def __init__(self):
        super(LSTM_pool, self).__init__()
        self.embedding = nn.Embedding(Config['vob_size'], Config['ebd_size'])
        # NOTE(review): bidirectional is hardcoded to True here;
        # Config['bidirectiion'] (sic) is never read — confirm intent.
        self.lstm = nn.LSTM(
            input_size=Config['ebd_size'],
            hidden_size=Config['hidden_size'],
            num_layers=Config['num_layer'],
            bidirectional=True,
            batch_first=True,
            dropout=Config['drop'],  # applied between stacked layers (num_layer > 1)
        )
        self.cnn = nn.Sequential(
            nn.Conv1d(
                # input channels = embedding dim + 2*hidden, because the raw
                # embedding is concatenated with the bidirectional LSTM output
                in_channels=Config['hidden_size'] * 2 + Config['ebd_size'],
                out_channels=Config['cnn_channel'],
                kernel_size=Config['cnn_kernel']),
            nn.BatchNorm1d(Config['cnn_channel']),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=Config['cnn_channel'],
                out_channels=Config['cnn_channel'],
                kernel_size=Config['cnn_kernel']),
            nn.BatchNorm1d(Config['cnn_channel']),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Sequential(
            # k-max pooling keeps `topk` values per channel, flattened below
            nn.Linear(Config['topk'] * Config['cnn_channel'], Config['fc_hidden']),
            nn.BatchNorm1d(Config['fc_hidden']),
            nn.ReLU(inplace=True),
            nn.Linear(Config['fc_hidden'], Config['fc_cla'])
        )

    @staticmethod
    def topk_pooling(x, k, dim):
        """Return the k largest values of ``x`` along ``dim`` (descending).

        torch.topk already returns the values themselves; the original
        gather-by-topk-indices round trip produced the identical tensor
        and is removed.
        """
        return torch.topk(x, k, dim=dim)[0]

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: LongTensor of token ids, shape (B, S). The sequence must be
               long enough that S - 2*(cnn_kernel - 1) >= Config['topk'].

        Returns:
            FloatTensor of unnormalized class logits, shape (B, fc_cla).
        """
        emb = self.embedding(x)                        # (B, S, E)
        out, _ = self.lstm(emb)                        # (B, S, 2H)
        out = torch.cat([emb, out], dim=-1)            # (B, S, 2H+E)
        out = out.permute((0, 2, 1))                   # (B, 2H+E, S) for Conv1d
        out = self.cnn(out)                            # (B, C, S') with S' = S - 2*(kernel-1)
        # k-max pooling along the time axis
        x = self.topk_pooling(out, k=Config['topk'], dim=-1)  # (B, C, k)
        x = x.view((x.size(0), -1))                    # (B, C*k)
        logits = self.fc(x)
        return logits