主要是借助tensorboardX中的writer.add_image和torchvision.utils中的make_grid来生成的。
主要代码:
extract_result[i]=extract_result[i].permute(1,0,2,3).cpu() #1,c,h,w--->c,1,h,w: each channel becomes one grayscale "image" entry
extract_result[i]=make_grid(extract_result[i]) #concat the per-channel images into a single grid tensor
writer.add_image('step'+str(i),extract_result[i]) #add image to tensorboard under tag 'step<i>'
对于要提取的特征层,由于模型不同可能不太好提取,建议直接在模型的forward函数里面改:加入一个标志位,使forward根据不同的情况输出不同的类型。参考代码:
def forward(self, x, vs_feature=False):
    """Run the encoder/decoder and return saliency prediction maps.

    Args:
        x: input image batch; per the shape comments below, (b, 3, 352, 352)
           is assumed — TODO confirm against the caller.
        vs_feature: when True, additionally collect the intermediate feature
            maps for visualization and return only the finest-scale output.

    Returns:
        vs_feature=False: tuple of three sigmoid maps (d1, d2, d3) at
            decreasing resolution.
        vs_feature=True: (d1 sigmoid map, list of 14 intermediate tensors).
    """
    if vs_feature:
        features = []
    # ----- encoder: ResNet backbone stem + 4 stages -----
    x = self.resnet.conv1(x)
    x = self.resnet.bn1(x)
    x = self.resnet.relu(x)
    x = self.resnet.maxpool(x)
    x1 = self.resnet.layer1(x)   # x1: b,64,88,88
    x2 = self.resnet.layer2(x1)  # x2: b,128,44,44
    x3 = self.resnet.layer3(x2)  # x3: b,256,22,22
    x4 = self.resnet.layer4(x3)  # x4: b,512,11,11
    # receptive-field blocks: reduce channels per stage
    x1_rb = self.RFB_1(x1)  # x1_rb: b,32,88,88
    x2_rb = self.RFB_2(x2)  # x2_rb: b,64,44,44
    x3_rb = self.RFB_3(x3)  # x3_rb: b,128,22,22
    x4_rb = self.RFB_4(x4)  # x4_rb: b,256,11,11 (original comment said x1_rb)
    # ----- decoder: U-Net style upsample + skip concat -----
    d4 = self.up_4(x4_rb)               # d4: b,128,22,22
    d4 = torch.cat((x3_rb, d4), dim=1)  # d4: b,256,22,22
    d4 = self.up_conv4(d4)              # d4: b,128,22,22
    d3 = self.up_3(d4)                  # d3: b,64,44,44
    d3 = torch.cat((x2_rb, d3), dim=1)  # d3: b,128,44,44
    d3 = self.up_conv3(d3)              # d3: b,64,44,44
    d2 = self.up_2(d3)                  # d2: b,32,88,88
    d2 = torch.cat((x1_rb, d2), dim=1)  # d2: b,64,88,88
    d2 = self.up_conv2(d2)              # d2: b,32,88,88
    d1_output = self.output_conv_1(d2)  # d1_output: b,1,88,88
    d2_output = self.output_conv_2(d3)  # d2_output: b,1,44,44 (original comment said d1_output)
    d3_output = self.output_conv_3(d4)  # d3_output: b,1,22,22 (original comment said d1_output)
    # F.upsample is deprecated; F.interpolate is the numerically identical replacement.
    d1_output = F.interpolate(d1_output, scale_factor=4, mode='bilinear', align_corners=True)
    # d2_output = F.interpolate(d2_output, scale_factor=8, mode='bilinear', align_corners=True)
    # d3_output = F.interpolate(d3_output, scale_factor=16, mode='bilinear', align_corners=True)
    if vs_feature:
        features.extend([x1, x2, x3, x4, x1_rb, x2_rb, x3_rb, x4_rb,
                         d4, d3, d2, d1_output, d2_output, d3_output])
        # return the features you want to visualize alongside the prediction
        return d1_output.sigmoid(), features
    return d1_output.sigmoid(), d2_output.sigmoid(), d3_output.sigmoid()
完整的借助tensorboardX可视化的代码:
import os
import torch
import torch.nn as nn
from my_model import Scotasap_model
import tensorboardX as tbX
import argparse
import cv2
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image
from scipy import misc
# Command-line configuration for training / visualization.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=2, help='batch_size')
parser.add_argument('--epoch', type=int, default=10, help='训练总轮数')
parser.add_argument('--train_dir', type=str, default='./train_dataset/', help='训练集')
parser.add_argument('--test_dir', type=str, default='./test_dataset/', help='测试集')
# BUG FIX: default was the string '0.0001'; argparse re-parses string defaults
# through type=float so it happened to work, but a float literal is direct and
# consistent with the other numeric defaults (e.g. --grad_clip).
parser.add_argument('--lr', type=float, default=0.0001, help='学习率')
parser.add_argument('--model_dir', type=str, default='./trained_model/', help='模型存放位置')
parser.add_argument('--size', type=int, default=352, help='图像处理大小')
parser.add_argument('--max_length', type=int, default=90, help='最长序列的大小,不足这个长度全部补齐')
parser.add_argument('--process_image_size', type=int, default=60, help='经过神经网络处理后的图像的大小')
parser.add_argument('--grad_clip', type=float, default=5, help='梯度限幅')
parser.add_argument('--total_step', type=int, default=300, help='一个epoch中训练的步数')
parser.add_argument('--image_save_dir', type=str, default='./result_image_save/', help='照片结果储存路径')
parser.add_argument('--lr_decay_rate', type=float, default=0.3, help='学习率衰减率')
parser.add_argument('--lr_decay_epoch', type=int, default=2, help='多少个epcoh不变开始改变学习率')
parser.add_argument('--save_epoch', type=int, default=5, help='每训练几个epoch储存一下模型')
parser.add_argument('--checkpoint_savedir', type=str, default='./checkpoint/', help='储存模型checkpoint')
parser.add_argument('--channels', type=int, default=32, help='the main numbers of channel in the model')
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Visualization: run one image through the trained model and write every
# intermediate feature map to TensorBoard as an image grid.
# ---------------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
writer = tbX.SummaryWriter('runs/exp2')
model = Scotasap_model(args).to(device)
model.load_state_dict(torch.load(os.path.join(args.model_dir, 'parnet_best_test_value.pth')))
pic_dir = 'COD10K-CAM-1-Aquatic-1-BatFish-6.jpg'
image = cv2.imread(pic_dir, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; PIL/torchvision expect RGB
image = Image.fromarray(image)
image_process = transforms.Compose([
    transforms.Resize(size=(args.size, args.size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # normalize image, optional
])
image = image_process(image).unsqueeze(dim=0)  # 1,c,h,w
print(image.shape)
image = image.to(device)
model.eval()
with torch.no_grad():
    pred, extract_result = model(image, vs_feature=True)
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)  # min-max normalize to [0,1]
    pred = pred.squeeze().cpu().numpy()
    for i in range(len(extract_result)):
        # 1,c,h,w ---> c,1,h,w so make_grid tiles each channel as its own grayscale image
        extract_result[i] = extract_result[i].permute(1, 0, 2, 3).cpu()
        extract_result[i] = make_grid(extract_result[i])
        # NOTE(review): feature maps are not normalized to [0,1]; consider
        # make_grid(..., normalize=True) if grids come out clipped — confirm.
        writer.add_image('step' + str(i), extract_result[i])
writer.close()  # flush pending events to disk
# BUG FIX: scipy.misc.imsave was removed in SciPy 1.2, so the original call
# crashes on modern installs. pred is already normalized to [0,1], so scale to
# uint8 and save with the already-imported cv2.
cv2.imwrite('pred-' + pic_dir, (pred * 255).astype('uint8'))