pytorch tensorboardX 可视化特征图(多通道)

主要是借助tensorboardX中的writer.add_image和torchvision.utils中的make_grid来生成的。
主要代码:

  extract_result[i]=extract_result[i].permute(1,0,2,3).cpu()   #1,c,h,w--->c,1,h,w
        extract_result[i]=make_grid(extract_result[i])   #concat the images 
        writer.add_image('step'+str(i),extract_result[i])  #add image to tensorboard 

对于要提取的特征层,由于模型不同可能不太好提取,建议直接在模型的forward的函数里面改,加入一个标志位,使forward根据不同的情况输出不同的类型参考代码:

def forward(self,x,vs_feature=False):

        if vs_feature:
            features=[]
        #encode part
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x1 = self.resnet.layer1(x)      #x1 :b,64,88,88
        x2 = self.resnet.layer2(x1)     #x2 :b,128,44,44          
        x3 = self.resnet.layer3(x2)     #x3 :b,256,22,22
        x4 = self.resnet.layer4(x3)     #x4 :b,512,11,11 

        x1_rb=self.RFB_1(x1)   #x1_rb :b,32,88,88
        x2_rb=self.RFB_2(x2)   #x2_rb :b,64,44,44
        x3_rb=self.RFB_3(x3)   #x3_rb :b,128,22,22
        x4_rb=self.RFB_4(x4)   #x1_rb :b,256,11,11
        
        # decoding + concat path   U-net part
        d4=self.up_4(x4_rb)      #d4 :b,128,22,22
        d4 = torch.cat((x3_rb,d4),dim=1)  #d4 :b,256,22,22
        d4=self.up_conv4(d4) #d4 :b,128,22,22

        d3=self.up_3(d4)      #d3 :b,64,44,44
        d3 = torch.cat((x2_rb,d3),dim=1)  #d3 :b,128,44,44
        d3=self.up_conv3(d3) #d3 :b,64,44,44

        d2=self.up_2(d3)      #d2 :b,32,88,88
        d2 = torch.cat((x1_rb,d2),dim=1)  #d2 :b,64,88,88
        d2=self.up_conv2(d2) #d2 :b,32,88,88

        d1_output=self.output_conv_1(d2)  #d1_output:b,1,88,88
        d2_output=self.output_conv_2(d3)  #d1_output:b,1,44,44
        d3_output=self.output_conv_3(d4)  #d1_output:b,1,22,22

        d1_output=F.upsample(d1_output,scale_factor=4, mode='bilinear', align_corners=True)
        # d2_output=F.upsample(d2_output,scale_factor=8, mode='bilinear', align_corners=True)
        # d3_output=F.upsample(d3_output,scale_factor=16, mode='bilinear', align_corners=True)


        if vs_feature:
            features.extend([x1,x2,x3,x4,x1_rb,x2_rb,x3_rb,x4_rb,d4,d3,d2,d1_output,d2_output,d3_output])
            return d1_output.sigmoid(),features    #return features which you want to visualization

        return d1_output.sigmoid(),d2_output.sigmoid(),d3_output.sigmoid()  

完整的借助tensorboardX可视化的代码:

import os
import torch
import torch.nn as nn
from my_model import Scotasap_model
import tensorboardX as tbX
import argparse
import cv2
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image 
from scipy import misc

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size',type=int,default=2,help='batch_size')   
parser.add_argument('--epoch',type=int,default=10,help='训练总轮数')
parser.add_argument('--train_dir',type=str,default='./train_dataset/',help='训练集')
parser.add_argument('--test_dir',type=str,default='./test_dataset/',help='测试集')
parser.add_argument('--lr',type=float,default='0.0001',help='学习率')
parser.add_argument('--model_dir',type=str,default='./trained_model/',help='模型存放位置')
parser.add_argument('--size',type=int,default=352,help='图像处理大小')
parser.add_argument('--max_length',type=int,default=90,help='最长序列的大小,不足这个长度全部补齐')
parser.add_argument('--process_image_size',type=int,default=60,help='经过神经网络处理后的图像的大小')
parser.add_argument('--grad_clip',type=float,default=5,help='梯度限幅')
parser.add_argument('--total_step',type=int,default=300,help='一个epoch中训练的步数')
parser.add_argument('--image_save_dir',type=str,default='./result_image_save/',help='照片结果储存路径')
parser.add_argument('--lr_decay_rate',type=float,default=0.3,help='学习率衰减率')
parser.add_argument('--lr_decay_epoch',type=int,default=2,help='多少个epcoh不变开始改变学习率')
parser.add_argument('--save_epoch',type=int,default=5,help='每训练几个epoch储存一下模型')
parser.add_argument('--checkpoint_savedir',type=str,default='./checkpoint/',help='储存模型checkpoint')
parser.add_argument('--channels',type=int,default=32,help='the main numbers of channel in the model')
args = parser.parse_args()

device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

writer=tbX.SummaryWriter('runs/exp2')
model=Scotasap_model(args).to(device)
model.load_state_dict(torch.load(os.path.join(args.model_dir,'parnet_best_test_value.pth')))

pic_dir='COD10K-CAM-1-Aquatic-1-BatFish-6.jpg'
image=cv2.imread(pic_dir,cv2.IMREAD_COLOR)
image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)

image_process=transforms.Compose([
            transforms.Resize(size=(args.size,args.size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))   #归一化图像,可选
            ])
image=image_process(image).unsqueeze(dim=0)   #1,c,h,w
print(image.shape)
image=image.to(device)


model.eval()
with torch.no_grad():
 
    pred,extract_result = model(image,vs_feature=True)
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = pred.squeeze().cpu().numpy()

    for i in range(len(extract_result)):
        extract_result[i]=extract_result[i].permute(1,0,2,3).cpu()   #1,c,h,w--->c,1,h,w
        extract_result[i]=make_grid(extract_result[i])
        # print(extract_result[i].shape)
        writer.add_image('step'+str(i),extract_result[i])
    
    misc.imsave('pred-'+pic_dir,pred)
  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值