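# This script implements CAM (class activation map) visualization with PyTorch for 3-D medical images stored in NIfTI (.nii / .nii.gz) format.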
import numpy as np
from model import base_net
import cv2
import torch
import scipy.ndimage as ndimage
from skimage.transform import resize
from matplotlib import pyplot as plt
import nibabel
###--- Inputs: the MRI volume to visualize and the trained model weights ---###
MRI_path = 'data.nii.gz'
model_path = 'MyModel.pt'

# Read the NIfTI volume and convert it to float32.
MRI = nibabel.load(MRI_path)
MRI_array = MRI.get_fdata()
MRI_array = MRI_array.astype('float32')

# Data preprocessing: scale intensities by the volume maximum.
# A blank (all-zero) volume has a maximum of 0; skip the division in that
# case so the array is not turned into NaN values.
max_value = MRI_array.max()
if max_value != 0:
    MRI_array = MRI_array / max_value

# Prepend batch and channel axes so the volume becomes a 5-D tensor.
MRI_tensor = torch.as_tensor(MRI_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
# print('origin MRI shape: ', MRI_tensor.shape)

# Use the GPU when one is available; otherwise fall back to the CPU instead
# of crashing on the unconditional .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MRI_tensor = MRI_tensor.to(device)
grad_model = base_net().to(device)
# Use register_forward_hook() to capture the feature map of a chosen layer.
class LayerActivations:
    """Record the forward output of ``model[layer_num]`` through a forward hook.

    After the hooked module has run, its output (moved to the CPU) is
    available as ``features``. Call ``remove()`` when done, or use the
    instance as a context manager so the hook is always detached::

        with LayerActivations(container, 0) as activations:
            net(x)
        feature_map = activations.features

    Args:
        model: An indexable container of modules (``model[layer_num]`` must
            return an object exposing ``register_forward_hook``).
        layer_num: Index of the module inside ``model`` to observe.
    """

    # Most recent captured output; stays None until the hooked module runs.
    features = None

    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)

    def hook_fn(self, module, MRI_tensorut, output):
        """Store the hooked module's output on the CPU.

        ``module`` and ``MRI_tensorut`` (the module's input) are part of the
        forward-hook signature but are not used here. The second name looks
        like a search-and-replace artifact of ``input``; it is kept unchanged
        so that the method signature stays compatible.
        """
        self.features = output.cpu()

    def remove(self):
        """Detach the forward hook from the module."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always detach the hook; never swallow an exception from the body.
        self.remove()
        return False
# Load the trained weights.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source. map_location places the tensors on the same device as the
# input, which also lets a GPU-saved checkpoint load on a CPU-only machine.
grad_model.load_state_dict(torch.load(model_path, map_location=MRI_tensor.device))
grad_model.eval()

# Hook the layer to visualize; the second argument selects the i-th
# sub-module of the indexed container.
# Alternative target noted by the author for the training setup:
#   conv_out = LayerActivations(grad_model.Conv2.conv, 0)
conv_out = LayerActivations(grad_model.conv8, 0)  # test
try:
    # Inference only: no autograd graph is needed to read the activations.
    with torch.no_grad():
        output = grad_model(MRI_tensor)
    cam = conv_out.features  # output of the hooked layer
    # cam = output  # alternative: visualize the final network output instead
finally:
    # The hook must actually be removed; the bare attribute access
    # `conv_out.remove` (without parentheses) left it attached.
    conv_out.remove()

if cam is None:
    raise RuntimeError('The hooked layer did not run during the forward pass; '
                       'no feature map was captured.')

###--- Turn the captured feature map into a normalized 3-D heatmap ---###
print('cam.shape1', cam.shape)
cam = cam.cpu().detach().numpy().squeeze()
print('cam.shape2', cam.shape)
# Index of the feature map to display along the first axis left after
# squeeze() -- presumably the channel axis; TODO confirm for the hooked layer.
target_channel = 1
cam = cam[target_channel]
print('cam.shape3', cam.shape)

# Resample the feature map to the spatial size of the input volume.
capi = resize(cam, tuple(MRI_tensor.shape[2:]))
# print(capi.shape)

# Keep only non-negative responses, then min-max normalize to [0, 1].
capi = np.maximum(capi, 0)
value_range = capi.max() - capi.min()
if value_range > 0:
    heatmap = (capi - capi.min()) / value_range
else:
    # A constant map (e.g. every activation <= 0) would divide by zero and
    # yield NaN; use an all-zero heatmap instead.
    heatmap = np.zeros_like(capi)
def _draw_view_row(axes_row, mri_slice, cam_slice, view_name, title_size=25):
    """Draw one figure row: MRI slice, CAM slice, and their blended overlay.

    Args:
        axes_row: The three matplotlib axes of the row, left to right.
        mri_slice: 2-D slice of the MRI volume.
        cam_slice: 2-D slice of the heatmap, same shape as ``mri_slice``.
        view_name: Anatomical view name used in the first panel title.
        title_size: Font size shared by all panel titles.

    Returns:
        The image handle of the overlay panel (usable for a colorbar).
    """
    # cv2.addWeighted needs both inputs to share one dtype; the slices are
    # strided views of the volumes, so make them contiguous float32 arrays.
    mri_slice = np.ascontiguousarray(mri_slice, dtype=np.float32)
    cam_slice = np.ascontiguousarray(cam_slice, dtype=np.float32)
    # Overlay the weight map on the original image.
    overlay = cv2.addWeighted(mri_slice, 0.3, cam_slice, 0.6, 0)

    panels = (
        (mri_slice, 'gray', view_name + ' MRI'),
        (cam_slice, 'jet', 'Weight-CAM'),
        (overlay, 'jet', 'Overlay'),
    )
    for axis, (image, colormap, title) in zip(axes_row, panels):
        handle = axis.imshow(np.rot90(image, 1), cmap=colormap)
        axis.axis('off')
        axis.set_title(title, fontsize=title_size)
    return handle


f, axarr = plt.subplots(3, 3, figsize=(12, 12))
f.suptitle('CAM_3D_medical_image', fontsize=30)

# Slice 80 is shown in every view; clamp the index to the last valid slice
# so that volumes with fewer than 81 slices along an axis do not raise
# IndexError.
sagittal_slice_count = min(80, MRI_array.shape[0] - 1)
coronal_slice_count = min(80, MRI_array.shape[1] - 1)
axial_slice_count = min(80, MRI_array.shape[2] - 1)

# NOTE: the earlier ndimage.zoom(..., (1, 1), order=3) calls used a zoom
# factor of 1, i.e. they resampled each slice onto its own grid, so they are
# omitted. All rows use the same title font size (the coronal row used 50).

# Sagittal view (slice along the first axis)
img_plot = _draw_view_row(
    axarr[0],
    MRI_array[sagittal_slice_count, :, :],
    heatmap[sagittal_slice_count, :, :],
    'Sagittal',
)
# Axial view (slice along the third axis)
img_plot = _draw_view_row(
    axarr[1],
    MRI_array[:, :, axial_slice_count],
    heatmap[:, :, axial_slice_count],
    'Axial',
)
# Coronal view (slice along the second axis)
img_plot = _draw_view_row(
    axarr[2],
    MRI_array[:, coronal_slice_count, :],
    heatmap[:, coronal_slice_count, :],
    'Coronal',
)

# Colorbar for the last overlay panel, if needed.
plt.colorbar(img_plot, shrink=0.5)
# plt.show()
plt.savefig('CAM_demo_test.png')
# Example result figure: (image not included here)
# Reference links: (none listed here)