PyTorch可视化

(一)网络结构可视化

        使用torchinfo开源工具包进行结构化输出,只需要使用torchinfo.summary()就行。但需要注意的是可视化网络结构需要进行一次前向传播以获得特定层的信息。

import torchvision
from torchinfo import summary
resnet18=torchvision.models.resnet18()
summary(resnet18,(batchize=64,channels=3,high=224,weight=225))#进行一次前向传播以获得信息

(二)卷积CNN可视化

1.CNN卷积核可视化

        可视化卷积核是为了看模型提取哪些特征。

conv1 = dict(model.features.named_children())['conv1'] #以第1层卷积为例,可视化对应的参数
kernel_set = conv1.weight.detach()
num = len(conv1.weight.detach())
print(kernel_set.shape)
for i in range(0,num):
    i_kernel = kernel_set[i]
    plt.figure(figsize=(20, 17))
    if (len(i_kernel)) > 1:
        for idx, filer in enumerate(i_kernel):
            plt.subplot(9, 9, idx+1) 
            plt.axis('off')
            plt.imshow(filer[ :, :].detach(),cmap='bwr')

2.CNN特征图可视化

        特征图:输入的原始图像经过每次卷积层得到的数据。可视化特征图是为了看模型提取到的特征是什么样子的。

        这一般会使用PyTorch中的hook接口,相当于数据进行前向传播过程中的特征图会被hook捕捉,前向传播之后可以另行查看。一般 hook除了可以使用自己定义的之外,还可以使用PyTorch自带的nn.Module.register_forward_hook()等方式。

class Hook(object):
    def __init__(self):
        self.module_name = []
        self.features_in_hook = []
        self.features_out_hook = []

    def __call__(self,module, fea_in, fea_out):    #module是为了知道获得的是哪一层的数据
        print("hooker working", self)
        self.module_name.append(module.__class__)
        self.features_in_hook.append(fea_in)
        self.features_out_hook.append(fea_out)
        return None
    

def plot_feature(model, idx, inputs):
    hh = Hook()
    model.features[i].register_forward_hook(hh) #利用PyTorch自带的模型的层的属性
                                                 #register_forward_hook(),添加钩子。
    
    # forward_model(model,False)
    model.eval()
    _ = model(inputs)

#打印得到的信息
    print(hh.module_name)
    print((hh.features_in_hook[0][0].shape))
    print((hh.features_out_hook[0].shape))
    
    out1 = hh.features_out_hook[0]

    total_ft  = out1.shape[1]
    first_item = out1[0].cpu().clone()    

    plt.figure(figsize=(20, 17))
    

    for ftidx in range(total_ft):
        if ftidx > 99:
            break
        ft = first_item[ftidx]
        plt.subplot(10, 10, ftidx+1) 
        
        plt.axis('off')
        #plt.imshow(ft[ :, :].detach(),cmap='gray')
        plt.imshow(ft[ :, :].detach())

(三)使用TensorBoard可视化

#安装tensorboardX
from tensorboardX import SummaryWriter
#如果使用PyTorch自带的tensorboard,则下述引用
from torch.utils.tensorboard import SummaryWriter

#启动,命令行输入
tensorboard --logdir=/path/to/logs/ --port=xxxx
#为了tensorboard能够不断地在后台运行,也可以使用nohup命令或者tmux工具来运行tensorboard


#调用:
writer = SummaryWriter('./runs')#实例化SummaryWritter为变量writer,并指定writer的输出目录为当前目录下的"runs"目录
writer.add_graph(model, input_to_model = torch.rand(1, 3, 224, 224))#给定一个输入数据,前向传播后得到模型的结构,再通过TensorBoard进行可视化
writer.close()


#连续变量可视化
writer = SummaryWriter('./tb_logs')

for i in range(500):   #放一张图
    x = i
    writer.add_image("x", x, i) #日志中记录x在第step i 的值
writer.close()

#实际使用
writer = SummaryWriter('./logs')
writer.add_image(tudui,input)
writer.close()


#参数分布可视化:

for (key,val) in zip(resnet18.state_dict().keys(),resnet18.state_dict().values()):
    print(key,val)



writer = SummaryWriter('./tb_logs/')
for step, (key,val) in enumerate(zip(resnet18.state_dict().keys(),resnet18.state_dict().values()):
    writer.add_histogram("w", val, step)
    writer.flush()
writer.close()

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值