高版本transformers-4.24中的坑

最近遇到一个很奇怪的BUG,好早之前写的一个Bert文本分类模型,拿给别人用的时候,发现不灵了,原本90多的acc,什么都没修改,再测一次发现只剩30多了,检查了一番之后,很快我发现他的transformers版本是4.24,而我一直用的是4.9,没有更新。

于是我试着分析问题出在哪里,然后就遇到了这个坑。首先这是我模型的基础结构,很简单,就是一个Encoder模型加一层分类器:

class BertClassifier(torch.nn.Module):
    def __init__(self, bert_model, num_classes):
        super(BertClassifier, self).__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.2)
        self.dense = torch.nn.Linear(768, num_classes)
        
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)
        # print(list(self.bert.encoder.layer[0].attention.self.query.parameters()))
        # print(bert_out)
        sequence_output = bert_out.last_hidden_state
        print(sequence_output)
        sequence_output = self.dropout(sequence_output)
        pool_output = torch.mean(sequence_output, axis=1)
        
        logits = self.dense(pool_output)
        # print(logits)
        
        loss = None
        loss_fct = torch.nn.CrossEntropyLoss()
        
        if labels is not None:
            # labels = label.long()
            loss = loss_fct(logits, labels.view(-1))
        
        return loss if loss is not None else logits

为了分析问题出在哪里,我把类里的代码全都拿出来,逐行运行,发现最终的logits和正确的logits(在4.9版本的环境里执行的结果)是一致的,这就很奇怪了,但是我实例化模型,再用模型forward出来的结果却是错误的:

# 这个结果计算出来是对的
sequence_output = bert_cls_model.bert(**inputs).last_hidden_state
sequence_output = bert_cls_model.dropout(sequence_output)
pool_output = torch.mean(sequence_output, axis=1)
logits = bert_cls_model.dense(pool_output)
print(logits)

# 这样计算出来是错的
logits = bert_cls_model(**inputs)
print(logits)

于是我又在模型类的定义里打印了各个阶段的结果,如上第一段代码中的print,发现从bert_out的打印结果来看全都是错的。

更进一步地,为了确认是不是模型加载权重的时候出现了问题(比如加载权重后的模型被重新初始化了),我又在模型定义代码里打印了模型的参数值,确认参数值也是没有问题的。这就让我感到有些匪夷所思了。

我又按照同样的对比方法,在模型里边打印一次,单独拿出来打印一次,试着找出问题所在,这次是从一开始embedding开始,结果发现在模型内部和外部打印embedding的结果是一致的:

# 这样打印的结果是正确的
bert_cls_model.bert.embeddings(input_ids=inputs['input_ids'], token_type_ids=inputs['token_type_ids'])

# 在模型的forward方法里打印embedding的结果同样是正确的

更奇怪的是,我将embedding的结果输入给encoder手动计算,出来的sequence_out就变成正确的了:

class BertClassifier(torch.nn.Module):
    def __init__(self, bert_model, num_classes):
        super(BertClassifier, self).__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.2)
        self.dense = torch.nn.Linear(768, num_classes)
        
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
    	# 直接调用self.bert计算出来结果是错误的
        # bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)

		# 手动以此调用embedding和encoder,就算出来的结果就是正确的了
        embedding_res = self.bert.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        encoder_out = self.bert.encoder(embedding_res)
        sequence_output = encoder_out[0]

        sequence_output = self.dropout(sequence_output)
        pool_output = torch.mean(sequence_output, axis=1)
        
        logits = self.dense(pool_output)
        # print(logits)
        
        loss = None
        loss_fct = torch.nn.CrossEntropyLoss()
        
        if labels is not None:
            # labels = label.long()
            loss = loss_fct(logits, labels.view(-1))
        
        return loss if loss is not None else logits

最后我又额外检查了一遍两个版本源码的差别,也没有发现什么端倪,感觉修改的地方都是些写法的差异,不应该有能够造成这个问题的地方。

解决的话,目前就是把transformers的版本降下来,或者像最后这样手动执行计算,还没有发现真正出问题的地方在哪里,如果有哪位也遇到这个问题并且有效解决了的话,还请在评论区指出,谢谢。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值