TextCNN
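Yoon Kim proposed TextCNN in the paper Convolutional Neural Networks for Sentence Classification (EMNLP 2014).
It applies a convolutional neural network (CNN) to text classification, using several kernels of different sizes to extract key information from a sentence (similar to n-grams with multiple window sizes), which lets it capture local correlations better.
TextCNN extracts information from the text sequence with convolution kernels, then uses max pooling to keep the strongest response of each kernel as its feature.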
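The convolution and max pooling step described above can be sketched in a few lines of PyTorch. This is only an illustration under assumed sizes: the batch size, sentence length, embedding dimension, and window sizes are made up, and the input is random rather than real word embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
batch, length, emb_dim = 2, 7, 5
window_sizes = [2, 3, 4]  # each acts like an n-gram window of that many tokens

# A batch of sentences shaped (batch, channels=1, length, emb_dim)
x = torch.randn(batch, 1, length, emb_dim)

features = []
for k in window_sizes:
    # The kernel spans k tokens and the full embedding width
    conv = nn.Conv2d(1, 1, kernel_size=(k, emb_dim))
    fmap = conv(x)  # (batch, 1, length - k + 1, 1)
    assert fmap.shape == (batch, 1, length - k + 1, 1)
    # Max over the time axis keeps the strongest response of this kernel
    pooled = fmap.max(dim=2).values  # (batch, 1, 1)
    features.append(pooled.view(batch, 1))

# One feature per window size, concatenated into a sentence vector
sentence_vec = torch.cat(features, dim=1)
print(sentence_vec.shape)  # torch.Size([2, 3])
```

Each window size contributes a single number per sentence here because only one output channel is used per kernel; using more output channels would give more features per window size.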
Core code
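import torch
import torch.nn as nn

class TextCNN(nn.Module):
    def __init__(self, base_size, kernels):
        super(TextCNN, self).__init__()
        self.base_size = base_size  # embedding dimension of each token
        self.kernels = kernels      # list of window sizes, in tokens
        # nn.ModuleList registers the conv layers as submodules, so their
        # parameters show up in model.parameters(); a plain list would hide them
        self.convs = nn.ModuleList()
        for k in kernels:
            conv = nn.Conv2d(1, 1, (k, base_size))
            self.convs.append(self.try_cuda(conv))

    def forward(self, data, length):
        # data: (batch, 1, length, base_size)
        result = []
        for i in range(len(self.kernels)):
            k_size = self.kernels[i]
            k_length = length - k_size + 1  # number of window positions
            conv = self.convs[i]
            # Pool over the time axis only; the conv output has width 1,
            # so a square (k_length, k_length) window would not fit
            pool = nn.MaxPool2d((k_length, 1))
            target = conv(data)          # (batch, 1, k_length, 1)
            target = pool(target)        # (batch, 1, 1, 1)
            target = target.view(-1, 1)  # (batch, 1)
            result.append(target)
        return torch.cat(result, dim=1)  # (batch, len(kernels))

    def try_cuda(self, target):
        # Move the module to the GPU when one is available
        if torch.cuda.is_available():
            return target.cuda()
        else:
            return target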
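For classification, the feature vector returned by forward would normally be passed to a classification layer. The sketch below is an assumption about how that might look, not code from the linked repository: the sizes, the `classifier` head, and the labels are hypothetical, and a random tensor stands in for the TextCNN output so that the block runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration
num_kernels, num_classes, batch = 3, 2, 4

# Stand-in for the (batch, len(kernels)) tensor that forward() returns
features = torch.randn(batch, num_kernels)

# Hypothetical classification head: one linear layer over the pooled features
classifier = nn.Linear(num_kernels, num_classes)
logits = classifier(features)  # (batch, num_classes)

# Made-up labels, only to show the loss computation
labels = torch.tensor([0, 1, 1, 0])
loss = nn.CrossEntropyLoss()(logits, labels)
loss.backward()  # gradients now reach the head's weights
```

In a real model the head and the TextCNN module would be trained together, so the gradients would also flow back into the convolution kernels.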
Full code
https://github.com/rookitkitlee/TextClassifier/tree/master/TextCNN