Pytorch:在with torch.no_grad()的代码段里临时允许计算梯度

一般情况下,训练神经网络时,pytorch的测试部分语句都会写在 with torch.no_grad() 的代码段下,以关闭tensor的自动求导、计算梯度功能,节省显存和运算时间。
但是有时会希望临时允许计算梯度,比如作者是在用pytorch_grad_cam生成神经网络的可解释热图时,它需要通过计算梯度生成,否则会报“element 0 of tensors does not require grad and does not have a grad_fn”的错误。
百度没有搜到怎么临时允许计算梯度,干脆去看了一下torch.no_grad的源码:

class no_grad(_DecoratorContextManager):
    ... # 一堆注释
    def __init__(self):
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

可以看到这个类是通过在__enter__和__exit__里调用torch.set_grad_enabled实现关闭梯度计算的,所以可以对应地写一个临时允许计算梯度的类:

class TemporaryGrad(object):
    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

使用的时候,和torch.no_grad一样,把代码放在with TemporaryGrad():里就行了,例如:

def getCAM(model: Module,
           image_tensor: torch.Tensor,
           transform: Compose,
           target_type: int,
           device: torch.device = torch.device('cpu')):
    with TemporaryGrad():
        target_layers = [model.layer4[-1]]
        cam = GradCAM(model=model, target_layers=target_layers)
        targets = [ClassifierOutputTarget(target_type)]
        grayscale_cam = cam(input_tensor=image_tensor.unsqueeze(0).to(device), targets=targets)
        grayscale_cam = grayscale_cam[0, :]
        visualization = show_cam_on_image(np.array(transformConvert(image_tensor, transform)) / 255, grayscale_cam, use_rgb=True)
    return visualization

上面的函数是用来获得CAM热图的,有需要的朋友也可以参考拿去用,注意image_tensor去掉了冗余的第一维,即形如(3,224,224)而不是(1,3,224,224)。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值