Pytorch 冻结参数更新的三种方法

在Pytorch中,默认情况下,所有设置requires_grad=True的张量都能跟踪它们的梯度计算历史并支持梯度计算。然而,在某些情况下,我们不需要这样做,例如,当我们已经训练了模型,只想将其应用于一些输入数据时,即我们只想通过网络进行前向计算。此时,这个禁用梯度跟踪操作就显得很重要,即冻结某个变量或模块的参数更新。下面介绍三种方式实现这个操作。
第一种方法:requires_grad_(False)冻结

import torch
import torch.nn as nn
class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.l1 = nn.Linear(3,3).requires_grad_(False)
        self.l2 = nn.Linear(3,3)
    def forward(self, x):
        out = self.l1(x) +self.l2(x) 
        return out
model = my_model()
y=torch.rand(6,3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for i in range(2):
    data = torch.randn(6,3)
    out = model(x)
    loss=nn.functional.mse_loss(y,out)
    optimi
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值