构建基于BERT微调的多标签分类模型

实现继承BERT预训练模型的分类任务类

import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertConfig

# 构建基于BERT的微调模型类
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        # 导入参数设置对象
        model_config = BertConfig.from_pretrained(config.bert_path,     
                                                  num_labels=config.num_classes)
        # 导入基于bert-base-chinese的预训练模型
        self.bert = BertModel.from_pretrained(config.bert_path, config=model_config)

        # 此处用于调节是否将BERT纳入微调训练, 建议数据量+算力充足的情况下置为True
        # 如果设置为False, 则保持整个BERT网络参数不变, 微调仅仅针对最后的全连接层进行训练
        for param in self.bert.parameters():
            param.requires_grad = True

        # 全连接层的出口维度, 取决于具体的任务
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        # x[0]是输入的具体文本信息
        context = x[0]
        # x[1]是经过tokenizer处理后返回的attention mask张量
        # mask的尺寸size和输入相同, padding部分用0遮掩, 比如[1, 1, 1, 0, 0]
        mask = x[1]
        # x[2]是字符类型id
        token_type_ids = x[2]
        # 利用BERT模型得到输出张量, 并且只保留BertPooler的输出, 即第一个字符CLS对应的输出张量
        _, pooled = self.bert(context, attention_mask=mask, token_type_ids=token_type_id)
        # 再利用微调网络进一步提取特征, 并利用全连接层对特征张量进行维度变换
        out = self.fc(pooled)
        return out

对BERT模型的参数执行微调

        展示BERT模型中的参数命名:

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 将BERT中所有的参数层名字打印出来
        for name, param in self.bert.named_parameters():
            print(name)

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

针对BERT模型中的embedding层, 让其中的参数不参与微调

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 希望锁定embeddings层的参数, 不参与更新
        for name, param in self.bert.embeddings.named_parameters():
            print(name)
            param.requires_grad = False

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

BERT中的全连接层, 让其中的weight参数不参与微调

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 希望将全连接层中的.weight部分参数锁定
        for name, param in self.bert.named_parameters():
            if name.endswith('weight'):
                print(name)
                param.requires_grad = False

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

BERT中指定的若干层, 让其中的参数不参与微调

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 封闭BERT中的第1, 3, 5层参数, 不参与微调
        index_array = [1, 3, 5]
        for name, param in self.bert.named_parameters():
            new_x = name.split('.')[2]
            if new_x in index_array:
                print(name)
                param.requires_grad = False

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值