第五课 CNN文本识别
构造输入数据
输入数据的形状为 (batch_size, max_len, features_num)。
卷积核大小一般设为 (n × features_num),相当于每次选取 n 个词。使用不同数量、不同种类的卷积核,总共得到 $n_{total}$ 个特征图;由于不同大小的卷积核会得到不同大小的特征图,因此需要对每个特征图进行一次 pooling(池化)操作,使它们的输出大小一致。
网络模型架构
class Model(nn.Module):
    """TextCNN classifier: embedding -> parallel convolutions -> max pooling -> linear.

    One ``nn.Conv2d`` is built per entry of ``config.filter_sizes``. Each kernel
    has shape ``(k, config.embed)``, so it spans ``k`` consecutive tokens and the
    full embedding width. Every feature map is max-pooled over the sequence axis,
    the pooled vectors are concatenated, passed through dropout, and projected
    to ``config.num_classes`` logits.

    Attributes read from ``config``:
        embedding_pretrained: tensor of pretrained embedding weights, or ``None``.
        n_vocab: vocabulary size (used only when no pretrained weights are given).
        embed: embedding dimension; also the width of every convolution kernel.
        num_filters: number of output channels of each convolution.
        filter_sizes: iterable of kernel heights (tokens covered per kernel).
        dropout: dropout probability applied before the final linear layer.
        num_classes: number of output logits.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            # freeze=False keeps the pretrained vectors trainable (fine-tuning).
            self.embedding = nn.Embedding.from_pretrained(
                config.embedding_pretrained, freeze=False)
        else:
            # The last vocabulary index is reserved as the padding token.
            self.embedding = nn.Embedding(
                config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        # in_channels=1, out_channels=num_filters; one convolution per filter size.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed))
             for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(
            config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` + ReLU to ``x`` and max-pool over the sequence axis.

        Args:
            x: 4-D input ``[batch, 1, seq_len, embed]``.
            conv: one of the ``nn.Conv2d`` modules in ``self.convs``.

        Returns:
            Tensor of shape ``[batch, num_filters]``.
        """
        # The kernel covers the full embedding width, so the last axis has size 1
        # after the convolution and is squeezed away.
        # Example: [128, 256 (num_filters), 31] for a 2x300 kernel on 32 tokens.
        x = F.relu(conv(x)).squeeze(3)
        # Pooling window == remaining sequence length -> one value per filter.
        # Example: [128, 256].
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: either a tensor of token ids ``[batch, seq_len]``, or an
                indexable batch whose first element is that tensor (the
                remaining elements are ignored here). ``seq_len`` must be at
                least ``max(config.filter_sizes)`` because the convolutions
                use no padding.

        Returns:
            Logits of shape ``[batch, num_classes]``.
        """
        # Accept a bare tensor as well as the (tokens, ...) batch form; indexing
        # a bare tensor with [0] would silently select only the first sample.
        tokens = x if isinstance(x, torch.Tensor) else x[0]
        out = self.embedding(tokens)  # e.g. [128, 32, 300]
        out = out.unsqueeze(1)  # add channel axis, e.g. [128, 1, 32, 300]
        out = torch.cat(
            [self.conv_and_pool(out, conv) for conv in self.convs], 1)  # e.g. [128, 768]
        out = self.dropout(out)  # e.g. [128, 768]
        out = self.fc(out)  # e.g. [128, 10]
        return out