Torch生成类激活图CAM

import torch
from torch.nn import functional as F
from torchvision import models, transforms
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

# 加载经过训练的 ResNet 模型
model = models.resnet50(pretrained=True)
model.eval()

# 载入图像并进行预处理
image_path = 'airline.png'

image = Image.open(image_path).convert('RGB')
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image).unsqueeze(0)

# 前向传播获取特征图
with torch.no_grad():
    features = model.conv1(input_tensor)
    features = model.layer1(features)
    features = model.layer2(features)
    features = model.layer3(features)
    features = model.layer4(features)

# 获取模型的权重
weight = model.fc.weight

print(1)
# 假设 cam 和 resized_tensor 是 PyTorch 张量
# 将它们转换为 NumPy 数组

import cv2

bz, nc, h, w = features.shape
beforeDot =  features.reshape((nc, h*w))
cam = torch.matmul(weight[1], beforeDot)#404
cam = cam.reshape(h, w)

size_upsample = (256, 256)
cam = cam - torch.min(cam)
cam_img = cam / torch.max(cam)
# cam_img = torch.uint8(255 * cam_img)

# import torch
import torch.nn.functional as F
# 使用 interpolate 函数将其调整为 [224, 224]
resized_tensor = F.interpolate(cam_img.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)
# 现在 resized_tensor 是一个大小为 [1, 1, 224, 224] 的 PyTorch 张量
# 如果需要,你可以使用 .squeeze() 方法来移除不必要的维度
output_cam = resized_tensor.squeeze()

import numpy as np
cam_np =  output_cam.detach().numpy()
# 假设 image 是你的图像数据
# cam_np = cam_np.astype(np.uint8)

resized_tensor_np = input_tensor.detach().numpy()

# 将 image 的形状调整为 (3, 224, 224)
image = resized_tensor_np.squeeze()
# 转换图像通道顺序,从 (3, 224, 224) 调整为 (224, 224, 3)

image = np.transpose(image, (1, 2, 0))


import matplotlib.pyplot as plt
# 创建一个新的图形
plt.figure(figsize=(8, 8))
# 绘制原始图像
plt.subplot(1, 2, 1)
plt.imshow(image)#, cmap='gray')
plt.title('Original Image')
# 绘制 CAM
plt.subplot(1, 2, 2)
plt.imshow(cam_np, cmap='jet')  # 使用 'jet' 颜色映射以突出 CAM
plt.title('Class Activation Map (CAM)')
# 显示图形
plt.show()

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值