【脚本工具库】热力图生成(附源码)

热力图是一种可视化技术,可以通过颜色变化展示数据的强度或分布。在图像处理中,热力图常用于强调图像的某些区域,以帮助人们更直观地理解图像内容。本文将介绍两种生成热力图的方法,并提供完整的Python脚本源码。

方法一:使用预训练模型生成热力图

第一种方法是利用预训练的卷积神经网络(例如VGG19)生成热力图。这种方法可以突出图像中模型关注的区域。

from PIL import Image
import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
import torchvision.models as models
import matplotlib.pyplot as plt

def draw_CAM(model, img_path, save_path, resize=224, isSave=False, isShow=False):
    img = Image.open(img_path).convert('RGB')
    loader = transforms.Compose([transforms.Resize((resize, resize)), transforms.ToTensor()])
    img = loader(img).unsqueeze(0)

    model.eval()
    feature = model.features(img)
    feature_flatten = feature.view(feature.size(0), -1)
    output = model.classifier(feature_flatten)
    pred = torch.argmax(output).item()
    pred_class = output[:, pred]

    global feature_grad

    def hook_grad(grad):
        global feature_grad
        feature_grad = grad

    feature.register_hook(hook_grad)
    pred_class.backward(retain_graph=True)

    grads = feature_grad
    pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))
    pooled_grads = pooled_grads[0]
    features = feature[0]

    for i in range(len(features)):
        features[i, ...] *= pooled_grads[i, ...]

    heatmap = features.detach().numpy()
    heatmap = np.mean(heatmap, axis=0)
    heatmap = np.maximum(heatmap, 0)
    heatmap /= np.max(heatmap)

    img = cv2.imread(img_path)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = heatmap * 0.4 + img
    superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')

    if isSave:
        cv2.imwrite(save_path, superimposed_img)

    if isShow:
        # Display the original image and the heatmap side by side
        plt.figure(figsize=(10, 4))

        plt.subplot(1, 2, 1)
        plt.title('Original Image')
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.title('Heatmap')
        plt.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
        plt.axis('off')

        plt.show()

model = models.vgg19(pretrained=True)
draw_CAM(model, r'C:\Users\Heitie\Desktop\dog.jpg', '00001.jpg', isSave=True, isShow=True)
方法二:基于灰度图像生成热力图

第二种方法是将图像转换为灰度图像,然后应用热力图颜色映射。这种方法简单直接,适用于不需要复杂模型的情况。

import cv2
from matplotlib import pyplot as plt

# 加载图像
image_path = r'C:\Users\Heitie\Desktop\dog.jpg'
image = cv2.imread(image_path)

# 转换为灰度图像
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# 应用热力图颜色映射
heatmap_img = cv2.applyColorMap(gray_image, cv2.COLORMAP_JET)

# 显示原图和热力图
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.title('Original Image')
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title('Heatmap')
plt.imshow(cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB))
plt.axis('off')

plt.show()
使用说明
  1. 方法一:下载预训练模型(如VGG19),并加载目标图像。脚本会生成热力图并叠加在原图上,保存结果并可视化展示。
  2. 方法二:将图像转换为灰度图,然后应用热力图颜色映射,显示结果。
总结

这两种方法可以帮助你生成图像的热力图,突出图像的显著区域。希望这些脚本对你有所帮助。如果你有任何问题或建议,欢迎在评论区留言讨论。

感谢阅读!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值