[Text Classification] The RCNN Model

RCNN is another model commonly used for text classification; the original paper is Recurrent Convolutional Neural Networks for Text Classification.

The overall model structure is as follows:
(figure: RCNN model architecture)
The architecture consists of the following modules:
(1) a bidirectional RNN produces context information for each token (its hidden-state outputs);
(2) concatenating the hidden outputs with the original embeddings yields an enriched token representation;
(3) a TextCNN-style stack of convolution, max-pooling, and fully connected layers then produces the classification result.
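As a quick sanity check on step (2), the shapes work out as follows (a minimal sketch with illustrative sizes: the BiLSTM output has dimension 2H, so each concatenated token representation has dimension 2H + E):

```python
import torch
import torch.nn as nn

B, S, E, H = 2, 7, 50, 20   # batch, sequence length, embedding dim, hidden dim (illustrative)
emb = torch.randn(B, S, E)                      # token embeddings
lstm = nn.LSTM(input_size=E, hidden_size=H,
               bidirectional=True, batch_first=True)
out, _ = lstm(emb)                              # (B, S, 2H): forward/backward states concatenated
combined = torch.cat([emb, out], dim=-1)        # (B, S, 2H + E)
print(combined.shape)                           # torch.Size([2, 7, 90])
```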

The overall structure is quite clear; a simple PyTorch implementation is given below:


import torch
import torch.nn as nn


Config = {"vob_size": 100,         # vocabulary size
          "ebd_size": 50,          # embedding dimension
          "hidden_size": 20,       # LSTM hidden size
          "num_layer": 2,          # number of LSTM layers
          "bidirectional": True,   # bidirectional LSTM
          "drop": 0.3,             # dropout rate
          "cnn_channel": 100,      # output channels of the 1D CNNs
          "cnn_kernel": 3,         # kernel size of the 1D CNNs
          "topk": 10,              # keep the top-k values of the CNN output
          "fc_hidden": 10,         # hidden size of the classifier head
          "fc_cla": 4,             # number of output classes
          }


class LSTM_pool(nn.Module):
    def __init__(self):
        super(LSTM_pool, self).__init__()
        self.embedding = nn.Embedding(Config['vob_size'], Config['ebd_size'])
        self.lstm = nn.LSTM(
            input_size=Config['ebd_size'],
            hidden_size=Config['hidden_size'],
            num_layers=Config['num_layer'],
            bidirectional=True,
            batch_first=True,
            dropout=Config['drop']
        )

        self.cnn = nn.Sequential(
            nn.Conv1d(
                in_channels=Config['hidden_size'] * 2 + Config['ebd_size'],  # concat of BiLSTM output (2H) and embedding (E)
                out_channels=Config['cnn_channel'],
                kernel_size=Config['cnn_kernel']),
            nn.BatchNorm1d(Config['cnn_channel']),
            nn.ReLU(inplace=True),

            nn.Conv1d(
                in_channels=Config['cnn_channel'],
                out_channels=Config['cnn_channel'],
                kernel_size=Config['cnn_kernel']),

            nn.BatchNorm1d(Config['cnn_channel']),
            nn.ReLU(inplace=True)

        )

        self.fc = nn.Sequential(
            nn.Linear(Config['topk'] * Config['cnn_channel'], Config['fc_hidden']),   # flattened top-k features from every CNN channel
            nn.BatchNorm1d(Config['fc_hidden']),
            nn.ReLU(inplace=True),

            nn.Linear(Config['fc_hidden'], Config['fc_cla'])

        )

    @staticmethod
    def topk_pooling(x, k, dim):
        # Keep the k largest values along `dim`. Note that torch.topk returns
        # them in descending value order, so sequence order is not preserved.
        index = torch.topk(x, k, dim=dim)[1]
        return torch.gather(x, dim=dim, index=index)

    def forward(self, x):
        emb = self.embedding(x)
        out, _ = self.lstm(emb)    # (B, S, 2H)
        out = torch.cat([emb, out], dim=-1)   # (B, S, E) + (B, S, 2H) = (B, S, 2H+E)
        out = out.permute((0, 2, 1))    # (B, 2H+E, S)
        out = self.cnn(out)    # (B, C, S - 2*(cnn_kernel - 1)), no padding in either conv
        x = self.topk_pooling(out, k=Config['topk'], dim=-1)   # top-k along the sequence dim, (B, C, k)
        x = x.view((x.size(0), -1))    # (B, C*k)
        logits = self.fc(x)
        return logits
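The topk_pooling helper can be checked in isolation. Since torch.topk returns values in descending order, the pooled features lose their sequence order; the classic k-max pooling (as in Kalchbrenner et al.'s DCNN) preserves order by sorting the selected indices first, shown as an alternative below:

```python
import torch

x = torch.tensor([[[3., 1., 4., 1., 5., 9., 2.]]])   # (B=1, C=1, S=7)
idx = torch.topk(x, k=3, dim=-1)[1]                  # indices of the 3 largest values
pooled = torch.gather(x, dim=-1, index=idx)
print(pooled)                                        # tensor([[[9., 5., 4.]]])

# Order-preserving variant (classic k-max pooling): sort the indices first.
ordered = torch.gather(x, dim=-1, index=idx.sort(dim=-1)[0])
print(ordered)                                       # tensor([[[4., 5., 9.]]])
```

Whether sequence order matters here is debatable, since the following fully connected layer can learn either arrangement; the sorted variant simply stays closer to the original k-max pooling formulation.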