基于三维医学影像数据的CAM可视化——PyTorch实现

本代码基于PyTorch,实现面向三维医学影像数据(NIfTI格式,即 .nii / .nii.gz)的CAM(Class Activation Mapping,类激活映射)可视化。

import numpy as np
from model import base_net
import cv2
import torch
import scipy.ndimage as ndimage
from skimage.transform import resize
from matplotlib import pyplot as plt
import nibabel

# --- Load the input volume and instantiate the network ---
MRI_path = 'data.nii.gz'
model_path = 'MyModel.pt'

# Read the NIfTI volume and keep it as a float32 array.
MRI = nibabel.load(MRI_path)
MRI_array = MRI.get_fdata()
MRI_array = MRI_array.astype('float32')

# Preprocess: scale intensities by the volume maximum.  A volume whose
# maximum is exactly zero is left untouched, because dividing by it would
# turn every voxel into NaN/inf.
max_value = MRI_array.max()
if max_value != 0:
    MRI_array = MRI_array / max_value

# Prepend batch and channel axes; the code further down reads the spatial
# size from MRI_tensor.shape[2:5], i.e. it expects a 3D input volume.
MRI_tensor = torch.FloatTensor(MRI_array).unsqueeze(0).unsqueeze(0)
# print('origin MRI shape: ', MRI_tensor.shape)

# Run on the GPU when one is available, otherwise fall back to the CPU
# instead of failing on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MRI_tensor = MRI_tensor.to(device)
grad_model = base_net().to(device)

# Uses register_forward_hook() to capture a layer's feature map.
class LayerActivations:
    """Record the output of one sub-layer of a model via a forward hook.

    After the hooked model has been run, ``features`` holds the output of
    ``model[layer_num]`` from the most recent forward pass (moved to the CPU).

    The instance can be used as a context manager so that the hook is always
    released, even if the forward pass raises::

        with LayerActivations(net.block, 0) as acts:
            net(x)
        feature_map = acts.features
    """

    # Output captured by the hook; stays None until a forward pass has run.
    features = None

    def __init__(self, model, layer_num):
        """Register the hook on ``model[layer_num]``.

        Args:
            model: Indexable container of modules (anything supporting
                ``model[layer_num]``, e.g. an ``nn.Sequential``).
            layer_num: Index of the sub-module whose output is captured.
        """
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove the hook on exit; exceptions are never suppressed."""
        self.remove()
        return False

    def hook_fn(self, module, MRI_tensorut, output):
        """Forward-hook callback: keep a CPU copy of the layer output.

        Args:
            module: The hooked module (unused).
            MRI_tensorut: The inputs PyTorch passes to the hooked module
                (unused; the name is kept for interface compatibility).
            output: The output produced by the hooked module.
        """
        self.features = output.cpu()

    def remove(self):
        """Detach the forward hook from the layer."""
        self.hook.remove()

# Load the trained weights onto whatever device the model already lives on,
# then switch to inference mode.
model_device = next(grad_model.parameters()).device
grad_model.load_state_dict(torch.load(model_path, map_location=model_device))
grad_model.eval()

# Hook sub-layer 0 of the chosen block (second argument = index in the block).
# Alternative target: LayerActivations(grad_model.Conv2.conv, 0)
conv_out = LayerActivations(grad_model.conv8, 0)

try:
    # Only the activations are read afterwards, so no autograd graph is
    # needed for this forward pass.
    with torch.no_grad():
        output = grad_model(MRI_tensor)
    cam = conv_out.features  # output of the hooked layer
    # To visualize the final network output instead, use: cam = output
finally:
    # Actually call remove(); a bare `conv_out.remove` is a no-op expression
    # and would leave the hook registered on the model.
    conv_out.remove()

# --- Turn the captured activations into a normalized 3D heatmap ---
print('cam.shape1', cam.shape)
cam = cam.cpu().detach().numpy().squeeze()
print('cam.shape2', cam.shape)
# Keep index 1 along the first remaining axis.
# NOTE(review): presumably this selects feature channel 1 of a batch-of-one
# activation; confirm the hooked layer has more than one channel, otherwise
# squeeze() has already dropped the channel axis.
cam = cam[1]
print('cam.shape3', cam.shape)

# Resample the activation map to the spatial size of the input volume.
capi = resize(cam, (MRI_tensor.shape[2], MRI_tensor.shape[3], MRI_tensor.shape[4]))
# print(capi.shape)

# Keep only non-negative activations, then min-max normalize to [0, 1].
capi = np.maximum(capi, 0)
value_range = capi.max() - capi.min()
if value_range == 0:
    # A constant map (e.g. every activation <= 0) has nothing to highlight;
    # dividing by a zero range would produce a NaN heatmap.
    heatmap = np.zeros_like(capi)
else:
    heatmap = (capi - capi.min()) / value_range
# --- Plot MRI slice, CAM slice and their blend for three orthogonal views ---
f, axarr = plt.subplots(3, 3, figsize=(12, 12))

f.suptitle('CAM_3D_medical_image', fontsize=30)

# Requested slice index along each axis; clamped below to the volume size so
# that volumes with fewer than 81 slices along an axis do not raise IndexError.
axial_slice_count = 80
coronal_slice_count = 80
sagittal_slice_count = 80

# One font size for every panel title keeps the three rows consistent.
title_fontsize = 25

# (title of the MRI panel, volume axis that is sliced, requested slice index),
# listed in row order.
views = (
    ('Sagittal MRI', 0, sagittal_slice_count),
    ('Axial MRI', 2, axial_slice_count),
    ('Coronal MRI', 1, coronal_slice_count),
)

for row, (mri_title, axis, requested_index) in enumerate(views):
    slice_index = min(requested_index, MRI_array.shape[axis] - 1)

    # cv2.addWeighted needs both inputs to share size and dtype, so take the
    # matching slices and bring both to contiguous float32 arrays.
    mri_slice = np.ascontiguousarray(
        np.take(MRI_array, slice_index, axis=axis), dtype=np.float32)
    cam_slice = np.ascontiguousarray(
        np.take(heatmap, slice_index, axis=axis), dtype=np.float32)

    # Blend the weight map with the original image.
    overlay = cv2.addWeighted(mri_slice, 0.3, cam_slice, 0.6, 0)

    panels = (
        (mri_slice, 'gray', mri_title),
        (cam_slice, 'jet', 'Weight-CAM'),
        (overlay, 'jet', 'Overlay'),
    )
    for col, (image, cmap, title) in enumerate(panels):
        img_plot = axarr[row, col].imshow(np.rot90(image, 1), cmap=cmap)
        axarr[row, col].axis('off')
        axarr[row, col].set_title(title, fontsize=title_fontsize)

# Color bar for the last panel drawn (the coronal overlay).
plt.colorbar(img_plot, shrink=0.5)
# plt.show()  # enable for interactive display
plt.savefig('CAM_demo_test.png')

结果示意图:

参考链接:

GitHub - KennanYang/CAM-demo-for-3d-medical-image: It's a CAM (Class Activation Mapping) demo for 3D medical images (PyTorch and 3D U-Net).
https://github.com/KennanYang/CAM-demo-for-3d-medical-image

  • 6
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
以下是一个简单的 PyTorch 代码示例,用于生成 CNN 回归模型 Grad-CAM 可视化:

```python
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import cv2
import numpy as np

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = None
        self.gradients = None

    def forward(self, x):
        for name, module in self.model.named_children():
            x = module(x)
            if name == self.target_layer:
                self.feature_maps = x
        return x

    def backward(self, target_class):
        one_hot = torch.zeros(self.outputs.size())
        one_hot[0][target_class] = 1
        self.model.zero_grad()
        self.outputs.backward(gradient=one_hot, retain_graph=True)
        self.gradients = self.target_layer.grad.cpu().data.numpy()[0]

    def generate(self, x, target_class):
        x = x.unsqueeze(0)
        x = Variable(x, requires_grad=True)
        self.forward(x)
        self.backward(target_class)
        weights = np.mean(self.gradients, axis=(1,2))
        cam = np.zeros(self.feature_maps.shape[1:], dtype=np.float32)
        for i, w in enumerate(weights):
            cam += w * self.feature_maps[i, :, :]
        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, x.shape[2:])
        cam = cam - np.min(cam)
        cam = cam / np.max(cam)
        return cam
```

使用时,需要将目标 CNN 模型和目标层名称传递给 GradCAM 类的构造函数。然后,可以使用 forward() 方法计算特征映射,使用 backward() 方法计算梯度,最后使用 generate() 方法生成 Grad-CAM 可视化。

示例代码中使用的目标 CNN 模型是一个简单的分类模型,因此需要根据实际情况进行修改。此外,还需要根据实际情况调整输入图像的大小和目标类别。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值