TextCNN实战-详细维度变化展示
前言
TextCNN(2014 kim) 我觉得是入门NLP的一个很好的技术。相对RNN,CNN的可并行能力可以让它实验起来很快。
一、原理概述
TextCNN使用多个卷积核提取特征,利用特征做文本分类的模型。大体步骤如下:
1.每一个卷积核:其中一维一定是和词向量维度相同,另一维表示卷积核一次看几个单词。
2.最大池化层:卷积核扫过文本向量后,要经过一个最大池化层,最终输出一个1个标量作为特征。
3.多个卷积核:一个卷积核输出一个标量特征。几个卷积核就生成几个标量特征。
4.正则化:随机将一些特征设定为0,防止过拟合。
5.非线性变换:将得到特征做一个线性变换再套上一个非线性函数,输出分类的结果。
二、Pytorch 代码实战(详细展示维度变化)
梳理所有步骤维度变化如下:
- 超参数:词典大小vocab_len = 1300,句子长度seq_len=10,词向量大小embedding_size=256,分类维度n_class=3, batch_size=32
- 词嵌入层 embedding:[32,10] ➡ [32,10,256] (unsqueeze) ➡ [32, 1, 10, 256]
- 卷积层-1:[32,1,10, 256] (经过 100个[3,256]的卷积核) ➡ [32,100,8,1] (squeeze) ➡ [32,100,8]
- 卷积层-2:[32,1,10, 256] (经过 100个[4,256]的卷积核) ➡ [32,100,7,1] (squeeze) ➡ [32,100,7]
- 卷积层-3:[32,1,10, 256] (经过 100个[5,256]的卷积核) ➡ [32,100,6,1] (squeeze) ➡ [32,100,6]
注意卷积层1、2、3是并行的。 - 池化层-1:[32,100,8] ➡ [32,100,1]
- 池化层-2:[32,100,7] ➡ [32,100,1]
- 池化层-3:[32,100,6] ➡ [32,100,1]
- 拼接所有特征:3个[32,100,1] (cat) ➡ [32,300,1] (squeeze) ➡ [32,300]
- dropout层:[32,300] ➡ [32,300]
- 全连接层:[32,300] ➡ [32,3]
代码如下:
import torch
from torch import nn
BAICH_SIZE = 32
VOCAB_LEN = 1300
EMBEDDING_SIZE = 256
N_CLASS = 3
class TextCNN(nn.Module):
def __init__(self, vocab_len, embedding_size, n_class):
super().__init__()
self.embedding = nn.Embedding(vocab_len, embedding_size)
self.cnn1 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[3, embedding_size])
self.cnn2 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[4, embedding_size])
self.cnn3 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[5, embedding_size])
self.max_pool1 = nn.MaxPool1d(kernel_size=8)
self.max_pool2 = nn.MaxPool1d(kernel_size=7)
self.max_pool3 = nn.MaxPool1d(kernel_size=6)
self.drop_out = nn.Dropout(0.2)
self.full_connect = nn.Linear(300, n_class)
def forward(self, x):
embedding = self.embedding(x)
embedding = embedding.unsqueeze(1)
cnn1_out0 = self.cnn1(embedding)
cnn2_out0 = self.cnn2(embedding)
cnn3_out0 = self.cnn3(embedding)
cnn1_out = cnn1_out0.squeeze(-1)
cnn2_out = cnn2_out0.squeeze(-1)
cnn3_out = cnn3_out0.squeeze(-1)
out1 = self.max_pool1(cnn1_out)
out2 = self.max_pool2(cnn2_out)
out3 = self.max_pool3(cnn3_out)
out = torch.cat([out1, out2, out3], dim=1).squeeze(-1)
out_drop = self.drop_out(out)
out_final = self.full_connect(out_drop)
return out
model = TextCNN(VOCAB_LEN, EMBEDDING_SIZE, N_CLASS)
总结
我实战代码的时候维度转换是最容易出错的,希望能对大家有所帮助。