今天调用CAM生成热图时碰到了一个问题具体是这行代码报错(完整的heatmap代码最后放出来):cam = GradCAM(model=Local_Branch_model, target_layer=target_layer, use_cuda=False)
每次执行就报下面这个错误。
修改第65行代码,发现传入的output变成了元组格式,因此获取其内容替换一下就行了,出现原因还不太清楚。
下面放上自定义模型读取检查点,并生成热图的代码:
# copy from https://github.com/jacobgil/pytorch-grad-cam/blob/master/cam.py
# 对单个图像可视化
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, \
deprocess_image, \
preprocess_image
from torchvision.models import resnet50
import cv2
import numpy as np
import os
from utils.Densenet import Densenet121_AG, Fusion_Branch, resnet_CA_instance
CKPT_PATH_G = 'E:/chuan/CODE/AG-CNN-master/CAD_models/AG_CNN_Global_epoch_69.pkl'
CKPT_PATH_L = 'E:/chuan/CODE/AG-CNN-master/CAD_models/AG_CNN_Local_epoch_69.pkl'
CKPT_PATH_F = 'E:/chuan/CODE/AG-CNN-master/CAD_models/AG_CNN_Fusion_epoch_69.pkl'
import torch
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# 1.加载模型
# model = resnet50(pretrained