库文件修改
Tips:多输入和多损失在计算传递过程均会报错
对应python环境下,修改库文件
"xxx\envs\pytorch_study\Lib\site-packages\pytorch_grad_cam\base_cam.py"
LOSS损失:
代码84行添加
if self.uses_gradients:
self.model.zero_grad()
loss = sum([target(output)
for target, output in zip(targets, outputs)])
loss = loss.mean() # 添加位置
oss.backward(retain_graph=True)
网络多输入:
103行修改
将input_tensor 改为多输入中的一个输入
eg:
def get_target_width_height(self,input_tensor: torch.Tensor) -> Tuple[int, int]:
input_tensor = input_tensor["image1.0"] # 修改位置
width, height = input_tensor.size(-1), input_tensor.size(-2)
return width, height