【pytorch模型实现4】TextCNN

24 篇文章 2 订阅
24 篇文章 0 订阅

TextCNN模型实现

NLP模型代码github仓库:https://github.com/lyj157175/Models

import torch 
import torch.nn as nn 
import torch.nn.functional as F

class TextCNN(nn.Module):

    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.max_seq_len = config.max_seq_len

        # 嵌入层
        if config.embedding_pretrain is not None:
            self.word_embed = nn.Embedding.from_pretrained(config.embedding_pretrain, freeze=False)
        else:
            self.word_embed = nn.Embedding(config.vocab_size, config.embedding_dim)
        
        # 卷积层
        # nn.Conv1d(in_channels, out_channels, kernel_size)
        # out: [batch_size, out_channels, n+2p-f/s+1]
        self.conv1d_1 = nn.Conv1d(config.embedding_dim, config.num_filters, config.filters[0])
        self.conv1d_2 = nn.Conv1d(config.embedding_dim, config.num_filters, config.filters[1])
        self.conv1d_3 = nn.Conv1d(config.embedding_dim, config.num_filters, config.filters[2])
        # 池化层
        self.Max_pool_1 = nn.MaxPool1d(self.max_seq_len - config.filters[0] + 1)
        self.Max_pool_2 = nn.MaxPool1d(self.max_seq_len - config.filters[1] + 1)
        self.Max_pool_3 = nn.MaxPool1d(self.max_seq_len - config.filters[2] + 1)
        # dropout层
        self.dropout = nn.Dropout(config.dropout)
        # 分类层
        self.fc = nn.Linear(len(config.filters) * config.num_filters, config.num_label)


    def forward(self, x):
        x = torch.Tensor(x).long()  # b, seq_len
        embed_x = self.word_embed(x)    # b, seq_len, embedding_dim
        embed_x = embed_x.transpose(2, 1).contiguous()  # b, embedding_dim, seq_len
        
        conv_x1 = F.relu(self.conv1d_1(embed_x))  # b, num_filters, n-f+1
        conv_x2 = F.relu(self.conv1d_2(embed_x))  # b, num_filters, n-f+1
        conv_x3 = F.relu(self.conv1d_3(embed_x))  # b, num_filters, n-f+1

        pool_x1 = self.Max_pool_1(conv_x1).squeeze()  # b, num_filters, 1 ---> b, num_filters
        pool_x2 = self.Max_pool_2(conv_x2).squeeze()  # b, num_filters, 1 ---> b, num_filters
        pool_x3 = self.Max_pool_3(conv_x3).squeeze()  # b, num_filters, 1 ---> b, num_filters

        out = torch.cat([pool_x1, pool_x2, pool_x3], dim=1)
        out = self.dropout(out)
        out = self.fc(out)
        return out 


class Config:
    def __init__(self):
        self.embedding_pretrain = None   # 是否使用预训练词向量
        self.vocab_size = 3000           # 此表大小
        self.embedding_dim = 100         # 词向量维度
        self.filters = [3, 4, 5]         # 卷积核尺寸
        self.num_filters = 100           # 每个尺寸卷积核数量
        self.max_seq_len = 50            # 最大句子长度
        self.dropout = 0.5               # dropout
        self.num_label = 2               # 类别数目
        

if __name__ == '__main__':
    config = Config()
    model = TextCNN(config)
    # print(model)
    x = torch.zeros([64, 50])
    out = model(x)
    print(out.shape)
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值