pytorch模型可视化的方法总结,(参数量内存,特征图内存),FLOPs和Parameters

本文介绍了如何使用torchsummary、torchviz、netron、tensorwatch和get_model_complexity_info等工具对PyTorch的SRCNN模型进行网络结构可视化和参数计算。通过实例展示了各方法的使用步骤,包括模型的参数量、FLOPs计算以及内存占用分析,帮助理解模型复杂性和资源需求。


这里主要介绍pytorch 模型的网络结构的可视化
以 SRCNN 为例子来说明可视化的方法,以及参数量的计算

模型所占内存 = (参数量内存,特征图内存),
模型计算量 = (浮点数计算量)

1. torchsummary

torchinfo

class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x


from torchinfo import summary
if __name__ == "__main__":
    modelviz = SRCNN()
    # 打印模型结构
    print(modelviz)
    summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
    for p in modelviz.parameters():
        if p.requires_grad
### PyTorch 模型可视化方法 PyTorch 提供了多种工具方法来实现模型可视化,帮助开发者理解神经网络的内部机制。这些方法涵盖了模型结构的静态展示、特征图的动态观察以及模型复杂度的分析等多个方面。 #### 1. 使用 `torchsummary` 查看模型结构参数量 `torchsummary` 是一个非常方便的工具,可以用来查看模型的结构参数量。它会输出每一层的名称、输出形状以及参数数量。 ```python from torchsummary import summary # 假设 model 是已经定义好的 PyTorch 模型 summary(model, input_size=(1, 28, 28)) ``` 这种方法适用于快速了解模型的整体结构参数分布,但不支持动态计算图的可视化。 #### 2. 使用 `torchviz` `Graphviz` 可视化计算图 `torchviz` 是一个基于 `Graphviz` 的库,可以用来生成模型的计算图,展示模型中各个操作之间的依赖关系。 ```python import torch from torchviz import make_dot # 假设 model 是已经定义好的 PyTorch 模型,并且有一个输入张量 x = torch.randn(1, 1, 28, 28) y = model(x) dot = make_dot(y, params=dict(model.named_parameters())) dot.render("model_graph") # 保存为 PDF 或其他格式 ``` 这种方法能够生成详细的计算图,适用于理解模型内部的操作流程。 #### 3. 使用 `Netron` 可视化模型结构 `Netron` 是一个支持多种深度学习框架的模型可视化工具,可以通过加载 `.pt` 或 `.onnx` 文件来查看模型的结构。 ```python # 保存模型为 .pt 文件 torch.save(model.state_dict(), "model.pth") ``` 保存好模型后,可以直接在 `Netron` 的网页版或桌面版中打开模型文件,查看详细的模型结构层信息。 #### 4. 使用 `TensorBoard` 可视化模型结构 `TensorBoard` 是 PyTorch 提供的一个强大的可视化工具,可以通过 `SummaryWriter` 来记录模型的结构数据流。 ```python from torch.utils.tensorboard import SummaryWriter # 假设 model 是已经定义好的 PyTorch 模型 dummy_input = torch.rand(1, 1, 28, 28) writer = SummaryWriter() writer.add_graph(model, dummy_input) writer.close() ``` 运行完上述代码后,可以通过启动 `TensorBoard` 服务来查看模型的结构数据流。 #### 5. 使用 `get_model_complexity_info` 计算 FLOPs Parameters 对于需要分析模型计算复杂度的情况,可以使用 `get_model_complexity_info` 工具来计算模型FLOPs 参数量。 ```python from thop import profile # 假设 model 是已经定义好的 PyTorch 模型 input = torch.randn(1, 1, 28, 28) flops, params = profile(model, inputs=(input,)) print(f"FLOPs: {flops}, Parameters: {params}") ``` 这种方法适用于评估模型的计算效率资源消耗。 #### 6. 使用 `hiddenlayer` 进行高级可视化 `hiddenlayer` 是一个功能丰富的可视化库,支持多种深度学习框架,能够生成交互式的模型结构图。 ```python import hiddenlayer as hl # 假设 model 是已经定义好的 PyTorch 模型 graph = hl.build_graph(model, torch.zeros([1, 1, 28, 28])) graph.save("model_graph.png") ``` 这种方法适用于需要生成高质量模型结构图的场景。 #### 7. 使用 `TensorWatch` 进行实时可视化 `TensorWatch` 是微软开发的一个用于 PyTorch 的调试可视化工具,支持实时监控模型的训练过程。 ```python import tensorwatch as tw # 创建一个监视器 watcher = tw.Watcher() # 监视模型的参数 watcher.watch(model) ``` 通过 `TensorWatch`,可以实时查看模型的参数变化训练过程中的其他指标。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值