CNN的使用

在文本处理中,CNN模型的使用。
具体可以参考博文:
https://blog.csdn.net/sunny_xsc1994/article/details/82969867

一个过滤器的结果,直接做linear多分类

import torch.nn as nn
class E2EModel(nn.Module):
    def __init__(self):
        super(E2EModel, self).__init__()
        self.encode = BertModel.from_pretrained(model_name)
        self.conv = nn.Conv1d(
            in_channels=768,
            out_channels=64,
            kernel_size=(64,),
            stride=1,  # 水平和竖直方向滑动1
            bias=True,
            padding_mode='zeros'
        )
        self.pool=nn.MaxPool1d(128-64+1)
        self.drop=nn.Dropout(0.5)
        self.linear=nn.Linear(64,1)
        # self.linear2=nn.Linear(10,1)
    def forward(self, inputs_id, att_mask):
        x = self.encode(inputs_id, att_mask)[0]  # B*L*768
        input = x.permute(0, 2, 1)
        out = self.conv(input)
        out=self.pool(out)
        out = out.view(-1, out.size(1))#B*128
        out=self.drop(out)
        out=self.linear(out)
        
        # out=self.linear2(out)
        out=torch.sigmoid(out)
        return out

model=E2EModel()

多个不同的过滤器,得到的结果cat在一起。


import torch.nn as nn
window_sizes=[3,5]
class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.is_training = True
        self.dropout_rate = 0.5

        self.embedding = BertModel.from_pretrained(model_name)
        self.convs = nn.ModuleList([
            nn.Sequential(nn.Conv1d(in_channels=768,
                                    out_channels=50,
                                    kernel_size=h),
                          #                              nn.BatchNorm1d(num_features=config.feature_size),
                          nn.ReLU(),
                          nn.MaxPool1d(kernel_size=128 - h + 1))
            for h in window_sizes
        ])
        self.fc = nn.Linear(in_features=50 * len(window_sizes),
                            out_features=1)
        # if os.path.exists(config.embedding_path) and config.is_training and config.is_pretrain:
        #     print("Loading pretrain embedding...")
        #     self.embedding.weight.data.copy_(torch.from_numpy(np.load(config.embedding_path)))

    def forward(self,  inputs_id, att_mask):
        embed_x = self.embedding( inputs_id, att_mask)[0]

        # print('embed size 1',embed_x.size())  # 32*35*256
        # batch_size x text_len x embedding_size  -> batch_size x embedding_size x text_len
        embed_x = embed_x.permute(0, 2, 1)
        # print('embed size 2',embed_x.size())  # 32*256*35
        out = [conv(embed_x) for conv in self.convs]  # out[i]:batch_size x feature_size*1
        # for o in out:
        #    print('o',o.size())  # 32*100*1
        out = torch.cat(out, dim=1)  # 对应第二个维度(行)拼接起来,比如说5*2*1,5*3*1的拼接变成5*5*1
        # print(out.size(1)) # 32*400*1
        out = out.view(-1, out.size(1))
        # print(out.size())  # 32*400
        # if not self.use_element:
        out = F.dropout(input=out, p=self.dropout_rate)
        out = self.fc(out)
        out=torch.sigmoid(out)
        return out

model=TextCNN()

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YingJingh

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值