首先，将网络模型的中间特征图输出并保存为 npy 格式（代码中 vis_and_save_heatmap 函数之前的 show_image_with_dice 是计算 Dice 的函数），代码如下：
import numpy as np
import torch.optim
from Load_Dataset import ValGenerator, ImageToImage2D
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
import Config as config
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from utils import *
import cv2
import pandas as pd
from networks.UNet import UNet
def show_image_with_dice(predict_save, labs, save_path):
    """Score a prediction against its label and write the prediction as an image.

    Args:
        predict_save: predicted mask as a NumPy array. It is written as
            ``predict_save * 255``, so presumably values are in {0, 1}
            (TODO confirm).
        labs: ground-truth mask with the same shape as ``predict_save``.
        save_path: destination file passed to ``cv2.imwrite``.

    Returns:
        tuple: ``(dice_pred, iou_pred)``.
        - Dice is ``2*|L*P| / (|L| + |P| + 1e-5)``.
        - IoU comes from ``jaccard_score`` on the flattened masks. This is
          presumably ``sklearn.metrics.jaccard_score`` pulled in via
          ``from utils import *``; verify.
    """
    tmp_lbl = labs.astype(np.float32)
    tmp_pred = predict_save.astype(np.float32)
    # The 1e-5 term keeps the ratio finite when both masks are empty.
    dice_pred = 2 * np.sum(tmp_lbl * tmp_pred) / (np.sum(tmp_lbl) + np.sum(tmp_pred) + 1e-5)
    iou_pred = jaccard_score(tmp_lbl.reshape(-1), tmp_pred.reshape(-1))
    # Compare string values with ``==``. ``is`` tests object identity and only
    # worked by accident of string interning.
    if config.task_name == "MoNuSeg":
        # NOTE(review): cv2.pyrUp's second positional parameter is ``dst``, not a
        # size. If (448, 448) was meant as the output size, pass
        # ``dstsize=(448, 448)`` instead.
        predict_save = cv2.pyrUp(predict_save, (448, 448))
        # MoNuSeg predictions are written at 2000x2000.
        predict_save = cv2.resize(predict_save, (2000, 2000))
    cv2.imwrite(save_path, predict_save * 255)
    return dice_pred, iou_pred
def vis_and_save_heatmap(model, input_img, img_RGB, labs, vis_save_path, dice_pred, dice_ens,
                         save_dir='./visual'):
    """Run ``model`` on one input and dump its four outputs as .npy files.

    The model must return exactly four tensors ``(a, b, c, d)``. Each output is
    moved to the CPU, converted to a NumPy array and written to
    ``<save_dir>/a.npy`` ... ``<save_dir>/d.npy``. The file names are fixed, so
    every call overwrites the previous files. After a full test loop they hold
    the outputs for the last image only.

    Args:
        model: network returning four tensors; switched to eval mode here.
        input_img: input tensor; moved to the GPU with ``.cuda()``.
        img_RGB, labs, vis_save_path, dice_pred, dice_ens: unused. Kept for
            call-site compatibility.
        save_dir: directory for the .npy files; created if missing.

    Returns:
        tuple of four np.ndarray: the saved outputs ``(a, b, c, d)``.
    """
    model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(input_img.cuda())
    a, b, c, d = (t.cpu().detach().numpy() for t in outputs)
    os.makedirs(save_dir, exist_ok=True)
    # np.save returns None. Do not assign its result, or the arrays are lost.
    for name, arr in zip('abcd', (a, b, c, d)):
        np.save(os.path.join(save_dir, name + '.npy'), arr)
    return a, b, c, d
if __name__ == '__main__':
    # Must be set before CUDA is initialised. Exposes only physical GPU 1 to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    test_session = config.test_session
    model_type = config.model_name
    # NOTE(review): the checkpoint directory is hard-coded to "./MoNuSeg/" whatever
    # config.task_name is (vis_path below uses task_name). Confirm this is intended.
    model_path = "./MoNuSeg/" + model_type + "/" + test_session + "/models/best_model-" + model_type + ".pth.tar"
    vis_path = "./" + model_type + config.task_name + '_visualize_test/'
    os.makedirs(vis_path, exist_ok=True)

    checkpoint = torch.load(model_path, map_location='cuda')
    if model_type == 'UNet':
        config_vit = config.get_CTranS_config()
        model = UNet(config_vit, n_channels=config.n_channels, n_classes=config.n_labels)
    else:
        raise TypeError('Please enter a valid name for the model type')
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        print("Let's use {0} GPUs!".format(torch.cuda.device_count()))
        # Use every visible GPU. A fixed device_ids=[0,1,2,3] fails when fewer
        # than four devices are visible.
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['state_dict'])
    print('Model loaded !')

    tf_test = ValGenerator(output_size=[config.img_size, config.img_size])
    test_dataset = ImageToImage2D(config.test_dataset, tf_test, image_size=config.img_size)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    # Progress-bar length: 14 for "Lung" (compared with ==, not `is`).
    # Other tasks fall back to the number of batches instead of an undefined name.
    test_num = 14 if config.task_name == "Lung" else len(test_loader)
    dice_pred = 0.0
    dice_ens = 0.0

    with tqdm(total=test_num, desc='Test visualize', unit='img', ncols=70, leave=True) as pbar:
        for i, (sampled_batch, names) in enumerate(test_loader, 1):
            image_name = os.path.splitext(names[0])[0]
            test_data, test_label = sampled_batch['image'], sampled_batch['label']
            arr = test_data.numpy().astype(np.float32)
            lab = test_label.data.numpy()
            # Assumes the label batch is (1, H, W) because batch_size=1. TODO confirm.
            img_lab = np.reshape(lab, (lab.shape[1], lab.shape[2])) * 255

            # Save the ground-truth label as a borderless grayscale image.
            plt.subplots()
            plt.imshow(img_lab, cmap='gray')
            plt.axis("off")
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.savefig(vis_path + str(i) + "_" + image_name + "_lab.jpg", dpi=300)
            plt.close()

            input_img = torch.from_numpy(arr)
            # Dumps the model's four outputs to ./visual/*.npy. Overwritten for each image.
            vis_and_save_heatmap(model, input_img, None, lab,
                                 vis_path + str(i),
                                 dice_pred=dice_pred, dice_ens=dice_ens)
            torch.cuda.empty_cache()
            pbar.update()
第二步，运行 visual.py 文件，读取上一步保存的 npy 特征图并绘制热力图，文件中代码如下：
import numpy as np
import mmcv
from mmengine.visualization import Visualizer
import torch
import matplotlib.pyplot as plt
# Render the feature map saved as ./visual/d.npy by the first script into a heatmap image.
visualizer = Visualizer()
# Presumably (1, C, H, W) from a batch_size=1 forward pass. TODO confirm.
all_dataset_generated_attention_score_maps = np.load('./visual/d.npy')
# Drop size-1 axes (the batch axis). draw_featmap requires a 3-D (C, H, W) map,
# so restore the channel axis if a single-channel map was squeezed to 2-D.
all_dataset_generated_attention_score_maps = np.squeeze(all_dataset_generated_attention_score_maps)
if all_dataset_generated_attention_score_maps.ndim == 2:
    all_dataset_generated_attention_score_maps = all_dataset_generated_attention_score_maps[np.newaxis]
drawn_img = visualizer.draw_featmap(torch.from_numpy(all_dataset_generated_attention_score_maps),
                                    channel_reduction='squeeze_mean')
# draw_featmap returns the rendered RGB image. Write that array directly.
# plt.savefig would save pyplot's current figure, not the one Visualizer.show()
# draws into, and would only run after the blocking show() returned.
plt.imsave('./visual/d.png', drawn_img)
visualizer.show(drawn_img)
最后就可以得到下面的特征图可视化结果（同时保存为 ./visual/d.png）。