特征图可视化(pytorch)

本篇博客的可视化是可视化网络的每层特征图,不是指类激活图(CAM)可视化,CAM可视化可以参考Grad-Cam实现流程(pytorch)
这篇博客的目的仅是记录而已,由于距离上次使用过于久远,具体参考文章已经找不到,因此结尾未加入参考链接.
可视化效果如下图:

浅层
在这里插入图片描述
深层
在这里插入图片描述
代码
利用tensorboard可视化特征图,以VGG16为例.

import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models
import torch
import torch.nn.functional as F

# ----------------------------------- feature map visualization -----------------------------------

writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

# 数据加载及预处理
path_img = "./Forsters_Tern_0016_152463.jpg"     # your path to image
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]

norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    norm_transform])
img_pil = Image.open(path_img).convert('RGB')
if img_transforms is not None:
    img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)    # chw --> bchw

# 模型加载
vggnet = models.vgg16_bn(pretrained=False)
pthfile = './pretrained/vgg16_bn-6c64b313.pth'
vggnet.load_state_dict(torch.load(pthfile))
# print(vggnet)

# 注册hook
fmap_dict = dict()
n = 0
# for name, sub_module in vggnet.named_modules():  # named_modules()返回网络的子网络层及其名称
#     if isinstance(sub_module, nn.Conv2d):
#         n += 1
#         print('Conv_'+str(n)+'_'+name)

def hook_func(m, i, o):
    # print(m)
    key_name = str(m.weight.shape)
    fmap_dict[key_name].append(o)

for name, sub_module in vggnet.named_modules():  # named_modules()返回网络的子网络层及其名称
    if isinstance(sub_module, nn.Conv2d):
        n += 1
        key_name = str(sub_module.weight.shape)
        # key_name = 'Conv_'+str(n)
        # Python 字典 setdefault() 函数和 get()方法 类似, 如果键不存在于字典中,将会添加键并将值设为默认值。
        fmap_dict.setdefault(key_name, list())
        # print(fmap_dict,'\n')

        n1, n2 = name.split(".")
        # print(n1,n2)
            # print(fmap_dict,'\n')
        # print(name)
        # print('1',vggnet._modules[n1]._modules[n2].named_modules())
        vggnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

# forward
output = vggnet(img_tensor)
print(fmap_dict['torch.Size([128, 64, 3, 3])'][0].shape)
# add image
for layer_name, fmap_list in fmap_dict.items():
    fmap = fmap_list[0]
    # print(fmap.shape)
    fmap.transpose_(0, 1)
    # print(fmap.shape)

    nrow = int(np.sqrt(fmap.shape[0]))
    # if layer_name == 'torch.Size([512, 512, 3, 3])':
    fmap = F.interpolate(fmap, size=[112, 112], mode="bilinear")
    
    fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
    print(type(fmap_grid),fmap_grid.shape)
    writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

  • 16
    点赞
  • 86
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值