在Pytorch中,默认情况下,所有设置requires_grad=True的张量都能跟踪它们的梯度计算历史并支持梯度计算。然而,在某些情况下,我们不需要这样做,例如,当我们已经训练了模型,只想将其应用于一些输入数据时,即我们只想通过网络进行前向计算。此时,这个禁用梯度跟踪操作就显得很重要,即冻结某个变量或模块的参数更新。下面介绍三种方式实现这个操作。
第一种方法:requires_grad_(False)冻结
import torch
import torch.nn as nn
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.l1 = nn.Linear(3,3).requires_grad_(False)
self.l2 = nn.Linear(3,3)
def forward(self, x):
out = self.l1(x) +self.l2(x)
return out
model = my_model()
y=torch.rand(6,3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for i in range(2):
data = torch.randn(6,3)
out = model(x)
loss=nn.functional.mse_loss(y,out)
optimi