一般情况下,训练神经网络时,pytorch的测试部分语句都会写在 with torch.no_grad() 的代码段下,以关闭tensor的自动求导、计算梯度功能,节省显存和运算时间。
但是有时会希望临时允许计算梯度,比如作者是在用pytorch_grad_cam生成神经网络的可解释热图时,它需要通过计算梯度生成,否则会报“element 0 of tensors does not require grad and does not have a grad_fn”的错误。
百度没有搜到怎么临时允许计算梯度,干脆去看了一下torch.no_grad的源码:
class no_grad(_DecoratorContextManager):
... # 一堆注释
def __init__(self):
if not torch._jit_internal.is_scripting():
super().__init__()
self.prev = False
def __enter__(self):
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
可以看到这个类是通过在__enter__和__exit__里调用torch.set_grad_enabled实现关闭梯度计算的,所以可以对应地写一个临时允许计算梯度的类:
class TemporaryGrad(object):
def __enter__(self):
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(True)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
使用的时候,和torch.no_grad一样,把代码放在with TemporaryGrad():里就行了,例如:
def getCAM(model: Module,
image_tensor: torch.Tensor,
transform: Compose,
target_type: int,
device: torch.device = torch.device('cpu')):
with TemporaryGrad():
target_layers = [model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
targets = [ClassifierOutputTarget(target_type)]
grayscale_cam = cam(input_tensor=image_tensor.unsqueeze(0).to(device), targets=targets)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(np.array(transformConvert(image_tensor, transform)) / 255, grayscale_cam, use_rgb=True)
return visualization
上面的函数是用来获得CAM热图的,有需要的朋友也可以参考拿去用,注意image_tensor去掉了冗余的第一维,即形如(3,224,224)而不是(1,3,224,224)。