pytorch的梯度传递

本文详细介绍了PyTorch中requires_grad属性如何影响梯度传递,包括三种基本情况:当一个输入张量需要记录梯度,另一个不需要时;两个输入张量都需记录梯度;以及两者都不需要。此外,还展示了如何通过设置requires_grad=False冻结骨干网络,以在微调模型时仅更新特定层。最后,讨论了在网络中,即使输入数据不记录梯度,由于网络参数记录梯度,导致中间和最终输出仍会记录梯度。
摘要由CSDN通过智能技术生成

1.requires_grad的传递

requires_gard 是tensor的一个属性,requires_gard=False表示不记录梯度,requires_gard=True表示记录张量的梯度。

每次的计算抽象为张量 A 与 B 做数学运算得到张量 C,C 是否记录梯度取决于 A 和 B的情况。
在这里插入图片描述

1.1三种情况下的梯度传递

  • A.requires_gard=False B.requires_gard=True ⇒ C.requires_gard=True
    A = torch.tensor([1., 2., 3.], requires_grad=True)
    B = torch.tensor([4., 5., 6.], requires_grad=False)
    C = A + B
    C.requires_grad
    ---------------------------------------------------------------------------------
    True
    
  • A.requires_gard=True B.requires_gard=False ⇒ C.requires_gard=True
    A = torch.tensor([1., 2., 3.], requires_grad=False)
    B = torch.tensor([4., 5., 6.], requires_grad=True)
    C = A + B
    C.requires_grad
    ----------------------------------------------------------------------------------
    True
    
  • A.requires_gard=False B.requires_gard=False ⇒ C.requires_gard=False
    A = torch.tensor([1., 2., 3.], requires_grad=False)
    B = torch.tensor([4., 5., 6.], requires_grad=False)
    C = A + B
    C.requires_grad
    -----------------------------------------------------------------------------------
    False
    

由此可见,只有当输入都不需要记录梯度时,后续计算的张量才不记录梯度只要有一个输入张量计算梯度,后续的张量均需要记录梯度

1.2利用requires_grad=False冻结骨干网络

# 获得pytorch的预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 冻结model的梯度计算
for p in model.parameters():
    p.requires_grad = False
# 替换最上层的fc
model.fc = torch.nn.Linear(512, 100)
# 新创建的liner层默认requires_grad=True
optmizer = torch.optim.SGD(model.fc.parameters(), lr=0.001)

1.3网络中的数据是记录梯度的

model = torchvision.models.resnet18(pretrained=True)
inputs = torch.randn(1, 3, 128, 128)
inputs.requires_grad
model(inputs).requires_grad
--------------------------------------------------------------------
False
True

虽然输入网络的tensor inputs 是不记录梯度的(requires_grad=False),但是网络的参数记录梯度,导致中间层的输出数据和最终的输出数据的requires_grad=True。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值