最近在使用pytorch中的一维卷积来对文本进行处理,进行文本分类任务,查阅了网上相关的博客还有api这里做一个总结。
一维卷积,顾名思义就是在一维空间上进行卷积,通常用来处理时序的数据,卷积的过程如下图。
进行卷积的数据形状为[batch_size,seq_len,embedding_dim],经过卷积以后变成了[batch_size,out_channels,sql_len-kernel_size+1]的形状,在卷积的时候是在最后一个维度进行的所以需要对数据进行点处理,具体如代码所示。
import torch.nn as nn
import torch
data = torch.randn(4,5,8)# [batch_size,seq_len,embedding_dim)
con1d = nn.Conv1d(in_channels=8,out_channels=16,kernel_size=2)
data = torch.transpose(data,2,1)# 同 data.permute(0,2,1)
con1d_out= con1d(data)#[batch_size,out_chanels,seq_len-kernel_size+1] ->[4, 16, 4]
print(con1d_out.shape)
print(con1d_out)
这里采用了tranpose对dim=1,dim=2的维度数据进行了交换,同样的使用