Freezing selected parameters when training a BERT model

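BERT has 12 encoder layers and roughly 100 million parameters, so when fine-tuning it you sometimes want to train only part of them. The remaining parameters are frozen, that is, held fixed, which still lets you fine-tune the model while making training more efficient. This is done with the requires_grad attribute of each parameter, which is used to freeze and unfreeze it.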
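As a minimal sketch of what requires_grad does, the snippet below uses a tiny two-layer stand-in model instead of BERT (the model and its sizes are assumptions made only for illustration). It freezes the first layer and checks which parameters receive a gradient after a backward pass.

```python
import torch
import torch.nn as nn

# Tiny stand-in model (hypothetical, not BERT): two linear layers.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))

# Freeze the first layer; the second layer stays trainable.
for param in model[0].parameters():
    param.requires_grad = False

# One forward/backward pass on random data.
loss = model(torch.randn(5, 4)).sum()
loss.backward()

# Frozen parameters never get a gradient; trainable ones do.
frozen_grads = [p.grad for p in model[0].parameters()]
trainable_grads = [p.grad for p in model[1].parameters()]
print("frozen layer has gradients:", any(g is not None for g in frozen_grads))
print("trainable layer has gradients:", all(g is not None for g in trainable_grads))
```

Because autograd skips the frozen layer entirely, no gradient memory is allocated for it, which is where the efficiency gain comes from.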

First, let's look at exactly which parameters the BERT model has:

bert.embeddings.word_embeddings.weight torch.Size([21128, 768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.LayerNorm.weight torch.Size([768])
bert.embeddings.LayerNorm.bias torch.Size([768])

bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.0.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.query.bias torch.Size([768])
bert.encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.key.bias torch.Size([768])
bert.encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.value.bias torch.Size([768])
bert.encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.1.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.1.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.1.output.dense.bias torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.2.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.query.bias torch.Size([768])
bert.encoder.layer.2.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.key.bias torch.Size([768])
bert.encoder.layer.2.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.value.bias torch.Size([768])
bert.encoder.layer.2.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.2.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.2.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.2.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.2.output.dense.bias torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.3.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.query.bias torch.Size([768])
bert.encoder.layer.3.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.key.bias torch.Size([768])
bert.encoder.layer.3.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.value.bias torch.Size([768])
bert.encoder.layer.3.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.3.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.3.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.3.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.3.output.dense.bias torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.4.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.query.bias torch.Size([768])
bert.encoder.layer.4.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.key.bias torch.Size([768])
bert.encoder.layer.4.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.value.bias torch.Size([768])
bert.encoder.layer.4.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.4.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.4.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.4.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.4.output.dense.bias torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.5.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.query.bias torch.Size([768])
bert.encoder.layer.5.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.key.bias torch.Size([768])
bert.encoder.layer.5.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.value.bias torch.Size([768])
bert.encoder.layer.5.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.5.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.5.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.5.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.5.output.dense.bias torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.6.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.query.bias torch.Size([768])
bert.encoder.layer.6.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.key.bias torch.Size([768])
bert.encoder.layer.6.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.value.bias torch.Size([768])
bert.encoder.layer.6.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.6.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.6.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.6.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.6.output.dense.bias torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.7.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.query.bias torch.Size([768])
bert.encoder.layer.7.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.key.bias torch.Size([768])
bert.encoder.layer.7.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.value.bias torch.Size([768])
bert.encoder.layer.7.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.7.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.7.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.7.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.7.output.dense.bias torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.8.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.query.bias torch.Size([768])
bert.encoder.layer.8.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.key.bias torch.Size([768])
bert.encoder.layer.8.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.value.bias torch.Size([768])
bert.encoder.layer.8.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.8.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.8.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.8.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.8.output.dense.bias torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.9.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.query.bias torch.Size([768])
bert.encoder.layer.9.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.key.bias torch.Size([768])
bert.encoder.layer.9.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.value.bias torch.Size([768])
bert.encoder.layer.9.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.9.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.9.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.9.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.9.output.dense.bias torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.10.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.query.bias torch.Size([768])
bert.encoder.layer.10.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.key.bias torch.Size([768])
bert.encoder.layer.10.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.value.bias torch.Size([768])
bert.encoder.layer.10.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.10.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.10.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.10.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.10.output.dense.bias torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.query.bias torch.Size([768])
bert.encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.key.bias torch.Size([768])
bert.encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.value.bias torch.Size([768])
bert.encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.11.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.11.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.11.output.dense.bias torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.bias torch.Size([768])

bert.pooler.dense.weight torch.Size([768, 768])
bert.pooler.dense.bias torch.Size([768])
out.weight torch.Size([2, 768])
out.bias torch.Size([2])
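The "roughly 100 million parameters" figure can be checked by hand from the shapes in the listing above. The sketch below is plain arithmetic on those shapes; the helper names (linear, layer_norm) are made up for this example.

```python
# Sizes read off the parameter listing above.
H, FF = 768, 3072            # hidden size, feed-forward size
VOCAB, MAX_POS, TYPES = 21128, 512, 2
LAYERS, CLASSES = 12, 2


def linear(out_features, in_features):
    """Parameter count of a dense layer: weight plus bias."""
    return out_features * in_features + out_features


def layer_norm(size):
    """Parameter count of a LayerNorm: weight plus bias."""
    return 2 * size


embeddings = VOCAB * H + MAX_POS * H + TYPES * H + layer_norm(H)
encoder_layer = (
    3 * linear(H, H)                    # query, key, value
    + linear(H, H) + layer_norm(H)      # attention.output
    + linear(FF, H)                     # intermediate
    + linear(H, FF) + layer_norm(H)     # output
)
pooler = linear(H, H)
head = linear(CLASSES, H)               # the extra "out" layer

total = embeddings + LAYERS * encoder_layer + pooler + head
print(f"per encoder layer: {encoder_layer:,}")
print(f"total: {total:,}")
```

Most of the total sits in the 12 encoder layers, so freezing the lower layers removes the bulk of the parameters from training.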

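Suppose we now want to unfreeze the 11th and 12th encoder layers (named layer.10 and layer.11, since the numbering starts at 0) together with the bert.pooler and out layers, and freeze all the other parameters. How do we do that?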

PyTorch provides model.named_parameters() and requires_grad, so it is enough to loop over the parameters and set the flag on each one. The full code:

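import torch
import torch.nn as nn
from transformers import BertModel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Local directory holding the pretrained Chinese-BERT-wwm weights
        self.bert = BertModel.from_pretrained('pretrained_models/Chinese-BERT-wwm')
        # Two-class head on top of the pooled output
        self.out = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        # outputs[1] is the pooled output, shape (batch_size, 768)
        return self.out(outputs[1])


if __name__ == '__main__':
    model = Model()
    # Substrings of the parameter names that should stay trainable
    unfreeze_layers = ['layer.10', 'layer.11', 'bert.pooler', 'out.']

    # Print every parameter name and shape
    for name, param in model.named_parameters():
        print(name, param.size())

    print("*" * 30)
    print('\n')

    # Freeze everything, then unfreeze the parameters whose name matches
    for name, param in model.named_parameters():
        param.requires_grad = False
        for ele in unfreeze_layers:
            if ele in name:
                param.requires_grad = True
                break

    # Check: only the unfrozen parameters should be printed
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.size())

    # Pass only the parameters with requires_grad = True to the optimizer
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)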
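The name matching in the code above is a plain substring test, so the choice of patterns matters. The sketch below runs the same test on a handful of names taken from the parameter listing, without loading any model; the helper is_unfrozen is a name invented for this example.

```python
def is_unfrozen(name, patterns):
    """Same test as the loop above: any pattern is a substring of the name."""
    return any(p in name for p in patterns)


# A few parameter names taken from the listing.
names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.output.dense.weight",
    "bert.encoder.layer.1.output.dense.bias",
    "bert.encoder.layer.10.attention.self.query.weight",
    "bert.encoder.layer.11.output.LayerNorm.bias",
    "bert.pooler.dense.weight",
    "out.weight",
]

unfreeze_layers = ['layer.10', 'layer.11', 'bert.pooler', 'out.']
selected = [n for n in names if is_unfrozen(n, unfreeze_layers)]

# A looser pattern over-matches: 'layer.1' also hits layer.10 and layer.11.
loose = [n for n in names if is_unfrozen(n, ['layer.1'])]
# A trailing dot pins the match to a single layer.
strict = [n for n in names if is_unfrozen(n, ['layer.1.'])]

for n in selected:
    print("trainable:", n)
print("matches for 'layer.1':", len(loose))
print("matches for 'layer.1.':", len(strict))
```

Note that 'out.' does not match the many "output." parameters, because in those names "out" is followed by "p" rather than by a dot. To unfreeze a single-digit layer such as layer 1, a pattern with a trailing dot ('layer.1.') avoids picking up layers 10 and 11 as well.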

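Finally, when training, add the filter to the optimizer so that parameters with requires_grad = False are left out. The optimizer then holds only the trainable parameters, and the frozen ones are not updated during training.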
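To see the effect of the filter, here is a small sketch that again uses a tiny two-layer stand-in for BERT (the model, data and learning rate are assumptions for illustration). It builds the optimizer from the trainable parameters only, takes one step, and compares the weights before and after.

```python
import torch
import torch.nn as nn

# Tiny stand-in model (hypothetical, not BERT).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))

# Freeze the first layer.
for p in model[0].parameters():
    p.requires_grad = False

# Hand only the trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-2)

# Snapshot the weights before the update.
frozen_before = model[0].weight.detach().clone()
head_before = model[1].weight.detach().clone()

# One training step on random data.
loss = model(torch.randn(8, 4)).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(model[0].weight, frozen_before)
head_changed = not torch.equal(model[1].weight, head_before)
n_optimized = sum(p.numel() for g in optimizer.param_groups for p in g["params"])

print("frozen layer unchanged:", frozen_unchanged)
print("trainable layer updated:", head_changed)
print("parameters held by the optimizer:", n_optimized)
```

Filtering also keeps the optimizer from tracking state for parameters it will never update, which matters for Adam because it stores two extra tensors per parameter.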

 

Reference article:

"PyTorch: how do I precisely freeze a specific layer of a pretrained model? Is there a command for it?"
