模型可视化工具 torchinfo 计算每层输出的Output Shape

本文介绍了如何使用torchinfo库方便地查看ResNet152和ViViT神经网络结构,包括参数数量、计算量和输入输出尺寸,展示了ViViT模型的详细信息。通过实例展示了如何获取模型摘要,并对比了两种模型的特性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

官方文档,pypi

之前的方法是print(model),或者在debug的时候去取值,但是在pip install torchinfo ,可以很方便的查看网络中的信息:

# 官方的例子
from torchinfo import summary
import torchvision

model = torchvision.models.resnet152()
summary(model, (1, 3, 224, 224), depth=3)
# [看看之前的vivit的效果](https://blog.csdn.net/ResumeProject/article/details/123470594?)
# 效果还行
	img = torch.ones([1, 16, 3, 64, 64]).cuda()
    # b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)   p1=p2=16  # torch.Size([1, 16, 3, 64, 64]) -> torch.Size([1, 16, 16, 192])
    model = ViViT(224, 16, 100, 16).cuda()
    from torchinfo import *
    summary(
        model,                 # PyTorch model
        (1, 16, 3, 64, 64),    #  Shape of input data as a List/Tuple/torch.Size
        # dtypes=[torch.long],
        # verbose=2,
        # col_width=16,
        # col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
        # row_settings=["var_names"],
    )
====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
ViViT                                              --                        --
├─Transformer: 1                                   --                        --
│    └─ModuleList: 2-1                             --                        --
│    │    └─ModuleList: 3-1                        --                        444,288
│    │    └─ModuleList: 3-2                        --                        444,288
│    │    └─ModuleList: 3-3                        --                        444,288
│    │    └─ModuleList: 3-4                        --                        444,288
├─Transformer: 1                                   --                        --
│    └─ModuleList: 2-2                             --                        --
│    │    └─ModuleList: 3-5                        --                        444,288
│    │    └─ModuleList: 3-6                        --                        444,288
│    │    └─ModuleList: 3-7                        --                        444,288
│    │    └─ModuleList: 3-8                        --                        444,288
├─Sequential: 1-1                                  [1, 16, 16, 192]          --
│    └─Rearrange: 2-3                              [1, 16, 16, 768]          --
│    └─Linear: 2-4                                 [1, 16, 16, 192]          147,648
├─Dropout: 1-2                                     [1, 16, 17, 192]          --
├─Transformer: 1-3                                 [16, 17, 192]             --
│    └─LayerNorm: 2-5                              [16, 17, 192]             384
├─Transformer: 1-4                                 [1, 17, 192]              --
│    └─LayerNorm: 2-6                              [1, 17, 192]              384
├─Sequential: 1-5                                  [1, 100]                  --
│    └─LayerNorm: 2-7                              [1, 192]                  384
│    └─Linear: 2-8                                 [1, 100]                  19,300
====================================================================================================
Total params: 3,722,404
Trainable params: 3,722,404
Non-trainable params: 0
Total mult-adds (M): 30.39
====================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 20.37
Params size (MB): 14.89
Estimated Total Size (MB): 36.05
====================================================================================================

Process finished with exit code 0

### PyTorch .pth 文件的可视化 对于 `.pth` 文件中的模型权重进行可视化的操作,可以分为几个方面来考虑。一方面是对模型架构及其各层参数量等信息的展示;另一方面则是针对特定层(比如卷积层)内部特征图或其他形式的数据分布情况进行图形化表示。 #### 使用 `torchinfo` 库查看模型结构详情 为了更好地理解整个神经网络的设计以及每一步计算过程中张量的变化情况,推荐先利用第三方库 `torchinfo` 对加载后的模型进行全面解析并打印出来: ```python from torchinfo import summary import torch from model import Model # 假设这是自定义的一个类名 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Model().to(device) model.load_state_dict(torch.load('net.pth', map_location=device)) model.eval() dummy_input = torch.randn(1, 1, 30, 30).to(device) summary(model, input_size=(1, 1, 30, 30), col_names=["input_size", "output_size", "num_params"], depth=5)[^2] ``` 这段代码会输出详细的表格描述了各个子模块的信息,包括但不限于输入/输出形状、参数数量等。 #### 特征图可视化 如果目标是在于观察某些中间层产生的激活值,则可以通过如下方式获取指定位置上的数据,并将其转换成易于绘图的形式: ```python import matplotlib.pyplot as plt import numpy as np def plot_feature_maps(feature_map_tensor): """绘制给定tensor对应的feature maps""" feature_map_array = feature_map_tensor.detach().cpu().numpy()[0] num_channels = feature_map_array.shape[0] fig, axes = plt.subplots(nrows=int(np.sqrt(num_channels)), ncols=int(np.sqrt(num_channels))) for i in range(min(len(axes.flat), num_channels)): ax = axes.flatten()[i] im = ax.imshow(feature_map_array[i], cmap='viridis') ax.axis('off') plt.tight_layout() plt.show() with torch.no_grad(): outputs = model(dummy_input) for name, module in model.named_children(): if isinstance(module, torch.nn.Conv2d): output = module(dummy_input) print(f"Visualizing {name} layer's feature maps.") plot_feature_maps(output) break # 如果只需要看第一个conv层的话 ``` 上述脚本选择了第一个遇到的二维卷积层作为例子进行了特征映射的提取与显示。当然也可以调整逻辑遍历更多类型的组件或是全部符合条件的对象。 #### TensorBoard 集成 除了直接生成静态图片外,还可以借助 TensorFlow 提供的强大工具——TensorBoard 来动态跟踪实验进展状况。这不仅限于简单的图像渲染,还包括损失函数曲线变化趋势等多种维度的表现形式。 ```python from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter(log_dir='./logs') writer.add_graph(model, dummy_input) writer.close() ``` 执行以上命令之后,在终端里启动 tensorboard (`tensorboard --logdir=./logs`) 并访问浏览器页面即可获得交互式的探索体验。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值