【特征可视化】超分辨率中的可视化特征图

该文介绍如何对基于Attention的训练模型进行测试,查看其在卷积层关注的特征。通过获取模型的特定层特征图,然后进行可视化,以理解网络在处理图像时关注的模式。文章提供了获取和展示特征图的代码,包括提取指定层的特征图以及保存和显示这些特征图。
摘要由CSDN通过智能技术生成

将基于Attention训练好的模型用测试集进行测试,并查看卷积层的特征图,看看加了注意力机制的网络“注意了”哪些特征?

  1. 获取模型中的特征子模块

首先得查看模型的有哪些特征子模块(除去池化层和激活函数层)

   model_layer= list(model.children())
  1. 如何获取卷积层中的feature maps?

#2. 获取第k层的特征图
'''
k:定义提取第几层的feature map
x:图片的tensor
model_layer:特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
    with torch.no_grad():
        for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
            x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
            if k == index:
                return x
  1. 如何可视化卷积层中的特征图?

#  可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
                                                                        # feature_map[2].shape     out of bounds
   feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
   feature_map_num = feature_map.shape[0]#返回通道数
   print("the num of all features:",feature_map_num)
   row_num = np.ceil(np.sqrt(feature_map_num))#8

   if(os.path.exists('feature_map_save') == False):
       os.mkdir('feature_map_save')

   plt.title("feature maps visualize")
   for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出

       plt.subplot(row_num, row_num, index)
       plt.imshow(feature_map[index - 1])#feature_map[0].shape=torch.Size([55, 55])
       plt.axis('off')
       plt.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])

   plt.show()
  1. 完整代码和特征图可视化

def get_image_info(image_dir):
   image_info = Image.open(image_dir).convert('RGB')
   y,cb,cr = image_info.split()
   # 数据预处理方法
   image_transform = transforms.Compose([
       #transforms.Resize(256),
       transforms.CenterCrop(224),
       transforms.ToTensor(),
       #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
   ])
   image_info = image_transform(y)#torch.Size([3, 224, 224])
   image_info = image_info.unsqueeze(0)#torch.Size([1, 3, 224, 224])因为model的输入要求是4维,所以变成4维
   return image_info#变成tensor数据

#2. 获取第k层的特征图
'''
k:定义提取第几层的feature map
x:图片的tensor
model_layer:特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
   with torch.no_grad():
       for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
           x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
           if k == index:
               return x

#  可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
                                                                        # feature_map[2].shape     out of bounds
   feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
   feature_map_num = feature_map.shape[0]#返回通道数
   print("the num of all features:",feature_map_num)
   row_num = np.ceil(np.sqrt(feature_map_num))#8

   if(os.path.exists('feature_map_save') == False):
       os.mkdir('feature_map_save')

   plt.title("feature maps visualize")
   for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出

       plt.subplot(row_num, row_num, index)
       plt.imshow(feature_map[index - 1])#feature_map[0].shape=torch.Size([55, 55])
       plt.axis('off')
       plt.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])
   plt.show()

if __name__ ==  '__main__':
   image_dir = "./bird.bmp"
   # 定义提取第几层的feature map
   k = 0
   image_info = get_image_info(image_dir)
   model = ESPCN(upscale_factor=4)
   model.load_state_dict('...your own model state_dict path...',map_location='cpu'))
   model_layer= list(model.children() )
   print("model childen part :",model_layer)
   feature_map = get_k_layer_feature_map(model_layer, 1, image_info)
   show_feature_map(feature_map)
  1. 可视化64个通道的特征

节选自以下博客:

基于深度学习使用Attention机制的图像超分辨率模型_梦中的伊犁河谷的博客-CSDN博客

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值