pytorch CAM得到热力图

pytorch CAM得到热力图

用法

根据网上的代码改成自己的模型的,并且改成了对整个文件夹的图片挨个生成
直接复制代码就行了,不用去下载包
修改main函数里的东西就行了,注释了的

下面这图是我自己的,跟想象的重点果然不一样啊

在这里插入图片描述

代码

增加了删除 hook 函数,不然在训练中查看每轮的热力图的话会内存泄漏,让本不富裕的内存雪上加霜
调用的话直接调用函数名就行了,改一下
finalconv_name = 'layer4' # 改成自己的层
classes = {0:'neg', 1:'pos'} # 改成自己的对应关系
params = list(net.parameters())[-2]	# 改成自己的层,一般可以不变
features_blobs[-1] # 根据自己的喜好改
在模型迭代过程中也可以使用,这样就可以看见模型在学哪个部分了,直接输入函数的参数就可以了
# simple implementation of CAM in PyTorch for the networks such as ResNet, DenseNet, SqueezeNet, Inception
from tqdm import tqdm
import io
import requests
from PIL import Image
import torch
from torchvision import models, transforms
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np
import cv2
import pdb
import os
import gc

def returnCAM(feature_conv, weight_softmax, class_idx, size_upsample=(384,384), use_gpu=False):
    # generate the class activation maps upsample to size_upsample
    bz, nc, h, w = feature_conv.shape
    output_cam = []
    for idx in class_idx:
        if use_gpu:
            cam = torch.mm(weight_softmax[idx].reshape(1,weight_softmax[idx].shape[0]), feature_conv.reshape((nc, h*w))).cpu().data.numpy()
        else:
            cam = torch.mm(weight_softmax[idx].reshape(1,weight_softmax[idx].shape[0]), feature_conv.reshape((nc, h*w))).data.numpy()
        cam = cam.reshape(h, w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))
    return output_cam
    
def camImages(net, img_dir, out_dir, size=(384,384), classes = {0:'neg', 1:'pos'}, use_gpu=False):
    '''
    net: 自己的模型,为了方便在训练中使用cam,所以是使用的已经加载好的
    img_dir: 你想对哪个文件夹下的图片使用
    out_dir: 你想保存输出图片在哪个文件夹
    size: 跟你网络的size保持一致
    classes: 自己类别的对应关系
    use_gpu: 是否使用gpu来加速
    '''
    
    if not os.path.exists(out_dir):
        # 判断存储文件是否存在
        os.mkdir(out_dir)

    # 我使用的是 ResNeSt50,最后的特征图名字为  layer4 ,可以提前print模型来查看你自己的名字
    finalconv_name = 'layer4'
    if use_gpu:
        net = net.cuda()
    else: 
    	net = net.cpu()
    net.eval()

    features_blobs = []
    def hook_feature(module, input, output):
        features_blobs.append(output.detach().data)
    # hook函数,使用handle作为返回是为了删除它,不然会内存泄漏
    handle = net._modules.get(finalconv_name).register_forward_hook(hook_feature)

    # 一般都是这个分布
    normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
    )

    preprocess = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    normalize
    ])

    # 开始单张图片的检测
    for _img in tqdm(os.listdir(img_dir)):
        img_path = os.path.join(img_dir, _img)  #获取图片路径
        img_pil = Image.open(img_path).convert('RGB')   #读取图片

        img_tensor = preprocess(img_pil)    #转为tensor
        img_variable = Variable(img_tensor.unsqueeze(0))
        if use_gpu:
            logit = net(img_variable.cuda()).detach().cpu()
        else:
            logit = net(img_variable).detach()
        
        params = list(net.parameters())[-2] # 这个 -2 要注意,改成自己网络的对应层数,全连接层前面那一层,一般的网络是 -2
        weight_softmax = np.squeeze(params.detach().data)        

        h_x = F.softmax(logit, dim=1).data.squeeze()
        probs, idx = h_x.sort(0, True)
        probs = probs.numpy()
        idx = idx.numpy()
        # features_blobs[-1]表示的是你之前的 finalconv_name 那一层的特征的最后的输出
        CAMs = returnCAM(features_blobs[-1], weight_softmax, [idx[0]], size_upsample=size, use_gpu=use_gpu)

        img = cv2.imread(img_path)
        height, width, _ = img.shape
        heatmap = cv2.applyColorMap(cv2.resize(CAMs[0],(width, height)), cv2.COLORMAP_JET)
        result = heatmap * 0.3 + img * 0.6
        cv2.imwrite(os.path.join(out_dir, _img), result)  
    # 这个时候删除 hook
    handle.remove()


if __name__ == "__main__":
    net_path = './checkpoints/resnest50_v7_momentum/40.pt'
    net = torch.load(net_path)
    use_gpu = False
    img_dir = './data_zf/neg'
    out_dir = './temp_img/neg'
    size = (368, 368)
    # 自己的类别对应关系
    classes = {0:'neg', 1:'pos'}

    camImages(net, img_dir, out_dir, size=size, classes=classes, use_gpu=use_gpu)
  • 12
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 33
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 33
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值