pytorch 参数冻结、加载、拓展

本博客主要介绍了

1.如何冻结模型的部分参数不更新(不参与训练)

2. 部分冻结的模型,加载时如何加载。

3. 如何拓展模型,在现有模型的基础上,拓展训练参数,进行增量训练

class MyModel(nn.Module):

    def __init__(self, num_user, pretrain):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrain)
        # 冻结self.bert的参数
        for param in self.bert.parameters():
            param.requires_grad = False

        self.text_pooler = nn.Linear(768, 768)

        self.user_emb = nn.Embedding(num_user, 768)
        self.linear1 = nn.Linear(768, 768*3)
        self.drop = nn.Dropout(0.3)
        self.linear2 = nn.Linear(768*3, 768)
    
    def extend_user(self, n):
        """拓展参数"""
        num_user = self.user_emb.num_embeddings
        new_users = n + num_user
        extended_user_emb = nn.Embedding(new_users, 768)
        extended_user_emb.weight.data[:num_user] = self.user_emb.weight.data
        self.user_emb = extended_user_emb

    def save(self, path):
        """保存的时候也不要保存bert的参数,只保存我训练的参数,模型体积大大减小"""
        model_state_dict = self.state_dict()
        non_bert_params = {key: value for key, value in model_state_dict.items() if 'bert' not in key}
        torch.save(non_bert_params, path)

    def load(self, path):
        """这里我们只加载了非bert的参数,而bert的参数在init方法中加载了"""
        file_path = os.path.join(path, 'model.bin')
        # 加载之前保存的非bert参数
        non_bert_params = torch.load(file_path)

        # 将非bert参数加载到模型中
        self.load_state_dict(non_bert_params, strict=False)

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中,要冻结某层参数,即保持其权重在训练过程中不发生更新。这种操作通常在迁移学习或固定特定的层的场景下使用,以便保留已经学到的知识。 要冻结某层参数,可以通过以下步骤实现: 1. 首先,加载模型并查看模型的结构。通过打印模型就可以看到每一层的名称以及对应的索引。 2. 在训练之前,确定需要冻结的层。可以通过模型的参数名称或索引来定位到具体的层。 3. 使用`requires_grad_()`函数来冻结参数,将需要冻结的层的`requires_grad`属性设置为False。这样,在反向传播过程中,这些参数的梯度就不会进行更新了。 4. 在训练过程中,只对其他未冻结的层进行梯度更新。 下面是一个简单的示例代码,演示如何冻结某层参数: ```python import torch import torch.nn as nn # 加载模型并创建优化器 model = torchvision.models.resnet18(pretrained=True) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 查看模型结构 print(model) # 冻结某层参数 # 可以通过模型的named_parameters()函数获取每一层的名称和参数 # 这里以冻结ResNet的第4个卷积层参数为例 for name, param in model.named_parameters(): if 'layer4' in name: # 可根据具体需求来决定冻结哪些层 param.requires_grad_(False) # 训练过程 for inputs, labels in dataloader: outputs = model(inputs) loss = loss_func(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() ``` 通过以上步骤,我们可以实现冻结某层参数的操作。这样,在训练过程中,被冻结的层的参数将不会更新,从而保持其固定的权重。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值