2021-08-01


前言

TextCNN(2014 kim) 我觉得是入门NLP的一个很好的技术。相对RNN,CNN的可并行能力可以让它实验起来很快。

一、原理概述

TextCNN使用多个卷积核提取特征,利用特征做文本分类的模型。大体步骤如下:

1.每一个卷积核:其中一维一定是和词向量维度相同,另一维表示卷积核一次看几个单词。
2.最大池化层:卷积核扫过文本向量后,要经过一个最大池化层,最终输出一个1个标量作为特征。
3.多个卷积核:一个卷积核输出一个标量特征。几个卷积核就生成几个标量特征。
4.正则化:随机将一些特征设定为0,防止过拟合。
5.非线性变换:将得到特征做一个线性变换再套上一个非线性函数,输出分类的结果。

二、Pytorch 代码实战(详细展示维度变化)

梳理所有步骤维度变化如下:

  1. 超参数:词典大小vocab_len = 1300,句子长度seq_len=10,词向量大小embedding_size=256,分类维度n_class=3, batch_size=32
  2. 词嵌入层 embedding:[32,10] ➡ [32,10,256] (unsqueeze) ➡ [32, 1, 10, 256]
  3. 卷积层-1:[32,1,10, 256] (经过 100个[3,256]的卷积核) ➡ [32,100,8,1] (squeeze) ➡ [32,100,8]
  4. 卷积层-2:[32,1,10, 256] (经过 100个[4,256]的卷积核) ➡ [32,100,7,1] (squeeze) ➡ [32,100,7]
  5. 卷积层-3:[32,1,10, 256] (经过 100个[5,256]的卷积核) ➡ [32,100,6,1] (squeeze) ➡ [32,100,6]
    注意卷积层1、2、3是并行的。
  6. 池化层-1:[32,100,8] ➡ [32,100,1]
  7. 池化层-2:[32,100,7] ➡ [32,100,1]
  8. 池化层-3:[32,100,6] ➡ [32,100,1]
  9. 拼接所有特征:3个[32,100,1] (cat) ➡ [32,300,1] (squeeze) ➡ [32,300]
  10. dropout层:[32,300] ➡ [32,300]
  11. 全连接层:[32,300] ➡ [32,3]

点击查看每个步骤的维度图片

代码如下:

import torch
from torch import nn
BAICH_SIZE = 32
VOCAB_LEN = 1300
EMBEDDING_SIZE = 256
N_CLASS = 3

class TextCNN(nn.Module):
    def __init__(self, vocab_len, embedding_size, n_class):
        super().__init__()

        self.embedding = nn.Embedding(vocab_len, embedding_size)

        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[3, embedding_size])
        self.cnn2 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[4, embedding_size])
        self.cnn3 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[5, embedding_size])

        self.max_pool1 = nn.MaxPool1d(kernel_size=8)
        self.max_pool2 = nn.MaxPool1d(kernel_size=7)
        self.max_pool3 = nn.MaxPool1d(kernel_size=6)

        self.drop_out = nn.Dropout(0.2)
        self.full_connect = nn.Linear(300, n_class)

    def forward(self, x):
        embedding = self.embedding(x)
        embedding = embedding.unsqueeze(1)
        
        cnn1_out0 = self.cnn1(embedding)
        cnn2_out0 = self.cnn2(embedding)
        cnn3_out0 = self.cnn3(embedding)

        cnn1_out = cnn1_out0.squeeze(-1)
        cnn2_out = cnn2_out0.squeeze(-1)
        cnn3_out = cnn3_out0.squeeze(-1)

        out1 = self.max_pool1(cnn1_out)
        out2 = self.max_pool2(cnn2_out)
        out3 = self.max_pool3(cnn3_out)

        out = torch.cat([out1, out2, out3], dim=1).squeeze(-1)

        out_drop = self.drop_out(out)
        out_final = self.full_connect(out_drop)
        return out
        
model = TextCNN(VOCAB_LEN, EMBEDDING_SIZE, N_CLASS)

总结

我实战代码的时候维度转换是最容易出错的,希望能对大家有所帮助。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值