torch中的retain graph、detach

本文介绍了PyTorch中retain_graph参数的作用,它用于在需要进行多次反向传播时保留计算图。此外,还解释了detach操作的使用场景,即在不希望某个变量参与后续反向传播时,可以detach该变量,使其变成常量。通过一个ContentLoss类的例子,展示了detach如何确保目标变量在计算损失时不影响梯度计算。
摘要由CSDN通过智能技术生成

1. retain graph

通常计算完backward之后会将计算图free掉以节省空间,但是如果需要做两次backwark,且有共用的graph时,需要先retain graph然后再进行backward。

Pytorch中retain_graph参数的作用 - Oldpan的个人博客前言 在pytorch神经网络迁移的官方教程中有这样一个损失层函数(具体看这里提供0.3.0版中文链接:https://oldpan.me/archives/pytorch-neural-transfer)。 class ContentLoss(nn.Module): def __init__(self, target, weight):……https://oldpan.me/archives/pytorch-retain_graph-work

2. detach

假设一个变量在graph中位置1时参与反向传播,需要更新权重;在位置2时不需要反向传播,只是当做一个常量来参与运算。此时,在位置2计算之前需要进行detach操作。

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        # 因为这里只是需要target这个数值,这个数值是一种状态,不计入计算树中。
        # 这里单纯将其当做常量对待,因此用了detach则在backward中计算梯度时不对target之前所在的计    
          算图存在任何影响。
        self.weight = weight
        self.criterion = nn.MSELoss()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值