Visualizing intermediate feature maps as heatmaps with Visualizer

First, output the network's intermediate feature maps and save them in .npy format (the function defined before vis_and_save_heatmap is the one that computes the Dice score):

import numpy as np
import torch.optim
from Load_Dataset import ValGenerator, ImageToImage2D
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
import Config as config
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from utils import *
import cv2
import pandas as pd
from sklearn.metrics import jaccard_score  # used by show_image_with_dice below
import torch.nn as nn  # needed for nn.DataParallel

from networks.UNet import UNet

def show_image_with_dice(predict_save, labs, save_path):
    tmp_lbl = (labs).astype(np.float32)
    tmp_3dunet = (predict_save).astype(np.float32)
    dice_pred = 2 * np.sum(tmp_lbl * tmp_3dunet) / (np.sum(tmp_lbl) + np.sum(tmp_3dunet) + 1e-5)
    iou_pred = jaccard_score(tmp_lbl.reshape(-1),tmp_3dunet.reshape(-1))
     
    if config.task_name == "MoNuSeg":
        predict_save = cv2.pyrUp(predict_save,(448,448))
        predict_save = cv2.resize(predict_save,(2000,2000))
        cv2.imwrite(save_path,predict_save * 255)
    else:
        cv2.imwrite(save_path,predict_save * 255)
     
    return dice_pred, iou_pred

def vis_and_save_heatmap(model, input_img, img_RGB, labs, vis_save_path, dice_pred, dice_ens):
    model.eval()

    ## Assume the model's forward pass returns the outputs of the four encoder
    ## stages (see the sketch after this script for one way to expose them).
    ## Save all four feature maps at once.
    a, b, c, d = model(input_img.cuda())
    a, b, c, d = a.cpu().detach().numpy(), b.cpu().detach().numpy(), c.cpu().detach().numpy(), d.cpu().detach().numpy()
    os.makedirs('./visual', exist_ok=True)  # make sure the output directory exists
    np.save('./visual/a.npy', a)  # np.save returns None, so do not reassign a/b/c/d
    np.save('./visual/b.npy', b)
    np.save('./visual/c.npy', c)
    np.save('./visual/d.npy', d)
    return a, b, c, d

if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    test_session = config.test_session
  
    if config.task_name == "Lung":
        test_num = 14
        model_type = config.model_name
        model_path = "./MoNuSeg/"+model_type+"/"+test_session+"/models/best_model-"+model_type+".pth.tar"

    save_path  = config.task_name +'/'+ model_type +'/' + test_session + '/'
    vis_path = "./" + model_type + config.task_name + '_visualize_test/'
    if not os.path.exists(vis_path):
        os.makedirs(vis_path)

    checkpoint = torch.load(model_path, map_location='cuda')

    if model_type == 'UNet':
        config_vit = config.get_CTranS_config()
        model = UNet(config_vit,n_channels=config.n_channels,n_classes=config.n_labels)
    
    else: raise TypeError('Please enter a valid name for the model type')

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        print ("Let's use {0} GPUs!".format(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=[0,1,2,3])
    model.load_state_dict(checkpoint['state_dict'])
    print('Model loaded !')
    tf_test = ValGenerator(output_size=[config.img_size, config.img_size])
    test_dataset = ImageToImage2D(config.test_dataset, tf_test,image_size=config.img_size)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    dice_pred = 0.0
    iou_pred = 0.0
    dice_ens = 0.0

    logs = pd.DataFrame(index=[], columns=['dice_pred_t', 'iou_pred_t'])
    with tqdm(total=test_num, desc='Test visualize', unit='img', ncols=70, leave=True) as pbar:
        for i, (sampled_batch, names) in enumerate(test_loader, 1):
             
            image_name = os.path.splitext(names[0])[0] 

            test_data, test_label = sampled_batch['image'], sampled_batch['label']
            arr = test_data.numpy()
            arr = arr.astype(np.float32)
            lab = test_label.data.numpy()
            img_lab = np.reshape(lab, (lab.shape[1], lab.shape[2])) * 255
            fig, ax = plt.subplots()
            plt.imshow(img_lab, cmap='gray')
            plt.axis("off")
            height, width = config.img_size, config.img_size
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.savefig(vis_path+str(i)+"_"+image_name+"_lab.jpg", dpi=300)
            plt.close()
            input_img = torch.from_numpy(arr)

            a, b, c, d = vis_and_save_heatmap(model, input_img, None, lab,
                                              vis_path + str(i),
                                              dice_pred=dice_pred, dice_ens=dice_ens)

            
            torch.cuda.empty_cache()
            pbar.update()   
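
The script above assumes that model(input_img) returns the four encoder outputs. A stock UNet forward usually returns only the segmentation logits, so the model has to be modified to expose the intermediate features. Below is a minimal sketch, reusing the UNet class imported above; the encoder attribute names (inc, down1, down2, down3) are assumptions and should be renamed to match your own implementation:

# Hypothetical subclass that returns the four encoder-stage feature maps.
# The attribute names below are assumptions; adapt them to your UNet.
class UNetWithFeatures(UNet):
    def forward(self, x):
        x1 = self.inc(x)     # output of the first encoder stage
        x2 = self.down1(x1)  # output of the second encoder stage
        x3 = self.down2(x2)  # output of the third encoder stage
        x4 = self.down3(x3)  # output of the fourth encoder stage
        return x1, x2, x3, x4

Alternatively, torch.nn.Module forward hooks can capture these activations without changing the model's forward code.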

Step two: run the visual.py file, whose code is shown below:

# Display the heatmap with mmengine's Visualizer
import numpy as np
import mmcv
from mmengine.visualization import Visualizer
import torch
import matplotlib.pyplot as plt
visualizer = Visualizer()
# Load the saved feature map (here the fourth encoder output) and drop the batch dimension
featmap = np.load('./visual/d.npy')
featmap = np.squeeze(featmap)  # (1, C, H, W) -> (C, H, W)

# draw_featmap expects a (C, H, W) tensor; 'squeeze_mean' averages over the channels
drawn_img = visualizer.draw_featmap(torch.from_numpy(featmap), channel_reduction='squeeze_mean')
visualizer.show(drawn_img)  # display the heatmap

plt.imsave('./visual/d.png', drawn_img)  # save the drawn heatmap image
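
draw_featmap can also blend the heatmap with the original image or tile individual channels. A small example continuing the script above (the input image path is hypothetical; overlaid_image, channel_reduction, topk and alpha are parameters of mmengine's Visualizer.draw_featmap):

# Optional: overlay the heatmap on the original input image.
import cv2

img_bgr = cv2.imread('./visual/input.jpg')          # hypothetical path to the input image
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # draw_featmap expects an RGB image

overlaid = visualizer.draw_featmap(
    torch.from_numpy(featmap),
    overlaid_image=img_rgb,          # the feature map is resized to the image and blended
    channel_reduction='select_max',  # or 'squeeze_mean'; with None, the top-k channels are tiled
    alpha=0.5,                       # blending weight of the heatmap
)
plt.imsave('./visual/d_overlay.png', overlaid)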

Finally, you obtain a heatmap image like the one below.

(Figure: the resulting feature-map heatmap)
