Freezing the parameters of specified BERT layers for training in PyTorch

Since BERT-base has 12 Transformer layers and roughly one hundred million parameters, fine-tuning sometimes calls for training only a subset of the parameters and freezing the rest. Freezing keeps most weights fixed while still allowing the model to be fine-tuned, and it improves training efficiency. The mechanism for this is the requires_grad attribute of each parameter, which lets us freeze and unfreeze parameters.
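As a minimal illustration of the mechanism (a toy sketch, not BERT-specific): setting requires_grad = False on a parameter excludes it from gradient computation, so no gradient is ever produced for it.

import torch
import torch.nn as nn

# A toy linear layer, just to show how requires_grad works.
layer = nn.Linear(4, 2)
layer.weight.requires_grad = False  # freeze the weight; the bias stays trainable

loss = layer(torch.randn(1, 4)).sum()
loss.backward()
print(layer.weight.grad)  # None: no gradient is computed for the frozen weight
print(layer.bias.grad)    # a tensor: the bias still receives gradients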

First, let's look at the concrete parameters of the BERT model (names and shapes):

bert.embeddings.word_embeddings.weight torch.Size([21128, 768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.LayerNorm.weight torch.Size([768])
bert.embeddings.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.0.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.query.bias torch.Size([768])
bert.encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.key.bias torch.Size([768])
bert.encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.value.bias torch.Size([768])
bert.encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.1.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.1.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.1.output.dense.bias torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.2.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.query.bias torch.Size([768])
bert.encoder.layer.2.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.key.bias torch.Size([768])
bert.encoder.layer.2.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.value.bias torch.Size([768])
bert.encoder.layer.2.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.2.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.2.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.2.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.2.output.dense.bias torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.3.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.query.bias torch.Size([768])
bert.encoder.layer.3.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.key.bias torch.Size([768])
bert.encoder.layer.3.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.value.bias torch.Size([768])
bert.encoder.layer.3.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.3.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.3.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.3.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.3.output.dense.bias torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.4.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.query.bias torch.Size([768])
bert.encoder.layer.4.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.key.bias torch.Size([768])
bert.encoder.layer.4.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.value.bias torch.Size([768])
bert.encoder.layer.4.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.4.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.4.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.4.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.4.output.dense.bias torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.5.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.query.bias torch.Size([768])
bert.encoder.layer.5.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.key.bias torch.Size([768])
bert.encoder.layer.5.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.value.bias torch.Size([768])
bert.encoder.layer.5.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.5.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.5.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.5.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.5.output.dense.bias torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.6.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.query.bias torch.Size([768])
bert.encoder.layer.6.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.key.bias torch.Size([768])
bert.encoder.layer.6.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.value.bias torch.Size([768])
bert.encoder.layer.6.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.6.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.6.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.6.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.6.output.dense.bias torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.7.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.query.bias torch.Size([768])
bert.encoder.layer.7.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.key.bias torch.Size([768])
bert.encoder.layer.7.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.value.bias torch.Size([768])
bert.encoder.layer.7.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.7.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.7.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.7.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.7.output.dense.bias torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.8.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.query.bias torch.Size([768])
bert.encoder.layer.8.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.key.bias torch.Size([768])
bert.encoder.layer.8.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.value.bias torch.Size([768])
bert.encoder.layer.8.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.8.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.8.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.8.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.8.output.dense.bias torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.9.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.query.bias torch.Size([768])
bert.encoder.layer.9.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.key.bias torch.Size([768])
bert.encoder.layer.9.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.value.bias torch.Size([768])
bert.encoder.layer.9.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.9.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.9.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.9.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.9.output.dense.bias torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.10.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.query.bias torch.Size([768])
bert.encoder.layer.10.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.key.bias torch.Size([768])
bert.encoder.layer.10.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.value.bias torch.Size([768])
bert.encoder.layer.10.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.10.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.10.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.10.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.10.output.dense.bias torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.bias torch.Size([768])
 
bert.encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.query.bias torch.Size([768])
bert.encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.key.bias torch.Size([768])
bert.encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.value.bias torch.Size([768])
bert.encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.11.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.11.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.11.output.dense.bias torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.bias torch.Size([768])
 
bert.pooler.dense.weight torch.Size([768, 768])
bert.pooler.dense.bias torch.Size([768])
out.weight torch.Size([2, 768])
out.bias torch.Size([2])
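As a sanity check on the "one hundred million parameters" figure, the total can be summed directly. A minimal sketch, assuming the Chinese-BERT-wwm checkpoint used in the code below:

from transformers import BertModel

bert = BertModel.from_pretrained('pretrained_models/Chinese-BERT-wwm')
# numel() gives the element count of each parameter tensor.
total = sum(p.numel() for p in bert.parameters())
print(f'total parameters: {total:,}')  # about 102 million for this BERT-base checkpoint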

Suppose we now want to unfreeze the 11th and 12th layers (layer.10 and layer.11 in the parameter names, which are zero-indexed) along with the bert.pooler and out parameters, and freeze everything else. How do we implement that?

PyTorch provides model.named_parameters() and the requires_grad attribute, so a simple loop that sets the flag is all we need. The concrete implementation:

import torch
import torch.nn as nn
from transformers import BertModel

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('pretrained_models/Chinese-BERT-wwm')
        self.out = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask=None):
        # Feed the pooled [CLS] representation to the classification head.
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        return self.out(outputs[1])

if __name__ == '__main__':
    model = Model()
    # layer.10 and layer.11 are the 11th and 12th encoder layers (zero-indexed).
    unfreeze_layers = ['layer.10', 'layer.11', 'bert.pooler', 'out.']

    for name, param in model.named_parameters():
        print(name, param.size())

    print('*' * 30)
    print('\n')

    # Freeze everything first, then unfreeze the selected layers.
    for name, param in model.named_parameters():
        param.requires_grad = False
        for ele in unfreeze_layers:
            if ele in name:
                param.requires_grad = True
                break

    # Verify: only the unfrozen parameters should be printed.
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.size())

    # Filter out the parameters with requires_grad = False.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

Finally, when training, be sure to add the filter in the optimizer so that parameters with requires_grad = False are excluded; those parameters will then not be updated during training.
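To double-check that the freezing worked, it helps to compare the number of trainable parameters against the total. A small sketch reusing the model object from the code above:

# Count trainable vs. total parameters after freezing.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'trainable: {trainable:,} / total: {total:,}')
# With only layer.10, layer.11, bert.pooler and out unfrozen,
# the trainable count should be a small fraction of the total.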
