网络结构
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/d3ef402d9bb507baa46eeda3f78abdab.jpeg)
代码实现
class TextCNN(nn.Module):
    """TextCNN text classifier over character-id and pinyin-id sequences.

    Both id sequences are embedded and joined along the sequence axis. The
    result goes through one ``Conv1d -> BatchNorm1d -> ReLU`` branch per
    kernel size in ``config.filter_list``; each branch is max-pooled over the
    sequence axis, the pooled vectors are concatenated, and a two-layer head
    maps them to 2 output scores per example.

    Args:
        config: Hyper-parameter object. The attributes read here are
            ``learning_rate``, ``keep_dropout``, ``sequence_length``,
            ``char_embedding_size``, ``pinyin_embedding_size``,
            ``filter_list``, ``out_channels``, ``l2_reg_lambda``,
            ``model_dir``, ``data_save_frequency`` and
            ``model_save_frequency``.
        char_size: Number of rows in the character embedding table.
        pinyin_size: Number of rows in the pinyin embedding table.

    Raises:
        ValueError: If ``char_embedding_size`` and ``pinyin_embedding_size``
            differ (see the check in ``__init__``).
    """

    def __init__(self,
                 config: "TCNNConfig",
                 char_size=5000, pinyin_size=5000):
        super().__init__()

        # Settings copied verbatim from the config object. Several of them
        # (learning rate, L2 lambda, paths, save frequencies) are only stored
        # here and are not used by the layers built below.
        self.learning_rate = config.learning_rate
        self.keep_dropout = config.keep_dropout
        self.sequence_length = config.sequence_length
        self.char_embedding_size = config.char_embedding_size
        self.pinyin_embedding_size = config.pinyin_embedding_size
        self.filter_list = config.filter_list
        self.out_channels = config.out_channels
        self.l2_reg_lambda = config.l2_reg_lambda
        self.model_dir = config.model_dir
        self.data_save_frequency = config.data_save_frequency
        self.model_save_frequency = config.model_save_frequency

        self.char_size = char_size
        self.pinyin_size = pinyin_size

        # forward() concatenates the two embeddings along the sequence axis
        # and the convolutions take ``char_embedding_size`` input channels, so
        # the two widths must agree. Fail here with a clear message instead of
        # with a shape error deep inside forward().
        if self.char_embedding_size != self.pinyin_embedding_size:
            raise ValueError(
                "char_embedding_size "
                f"({self.char_embedding_size}) and pinyin_embedding_size "
                f"({self.pinyin_embedding_size}) must be equal because the "
                "embeddings are concatenated along the sequence axis"
            )

        # Input channel count of every convolution branch.
        self.embedding_size = self.char_embedding_size
        # Width of the concatenated max-pooled features (one slice per branch).
        self.total_filters_size = self.out_channels * len(self.filter_list)

        self.build_model()

    def build_model(self):
        """Create the embedding tables, convolution branches and output head."""
        # nn.Embedding weights are trainable parameters by default.
        self.char_embeddings = nn.Embedding(
            self.char_size, self.char_embedding_size)
        self.pinyin_embeddings = nn.Embedding(
            self.pinyin_size, self.pinyin_embedding_size)

        # One convolution branch per kernel size.
        self.conv_lists_layer = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(self.embedding_size, self.out_channels, kernel_size),
                nn.BatchNorm1d(self.out_channels),
                nn.ReLU(inplace=True),
            )
            for kernel_size in self.filter_list
        )

        self.output_layer = nn.Sequential(
            # NOTE(review): nn.Dropout's argument is the probability of
            # *zeroing* an element. The config field is named ``keep_dropout``;
            # confirm it is not meant as a keep-probability.
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.total_filters_size, self.total_filters_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.total_filters_size, 2),
        )

    def forward(self, char_id, pinyin_id):
        """Compute the 2 output scores for a batch.

        Args:
            char_id: Integer tensor of character ids, (batch, char_len).
            pinyin_id: Integer tensor of pinyin ids, (batch, pinyin_len).

        Returns:
            Float tensor of shape (batch, 2).
        """
        char_vectors = self.char_embeddings(char_id)
        pinyin_vectors = self.pinyin_embeddings(pinyin_id)

        # Join along the sequence axis, then move the embedding axis to
        # position 1 because Conv1d expects (batch, channels, length).
        features = torch.cat((char_vectors, pinyin_vectors), dim=1)
        features = features.permute(0, 2, 1)

        # Max over the length axis: each branch yields (batch, out_channels).
        pooled = [
            branch(features).max(dim=2).values
            for branch in self.conv_lists_layer
        ]

        # Concatenation is already (batch, total_filters_size).
        return self.output_layer(torch.cat(pooled, dim=1))
实践总结
- TextCNN优点是模型简单、训练和预测的速度快;缺点是超参数(主要是卷积核列表)不易确定,效果不如BILSTM+ATTENTION。
相关文章
- 使用BILSTM+ATTENTION模型进行文本分类(链接);
- 使用Transformer模型进行文本分类(链接)