torchsummary
在 PyTorch 中,可以使用 torchsummary
库来获得类似于 TensorFlow 中 model.summary()
的清晰直观的模型摘要。torchsummary
可以显示模型的总体结构、每层的名称、输出形状以及参数数量。以下是如何使用 torchsummary
的详细说明。
安装 torchsummary
首先,需要安装 torchsummary
库:
pip install torchsummary
使用 torchsummary
显示模型摘要
下面是一个使用 torchsummary
的示例代码,展示如何在 PyTorch 中显示模型摘要:
import torch
import torch.nn as nn
from torchsummary import summary
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(784, 64)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, 10)
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
x = self.softmax(x)
return x
# 实例化模型
model = SimpleModel()
# 使用 torchsummary 显示模型摘要
summary(model, input_size=(1, 784))
详细说明
torchsummary 可以生成类似于 TensorFlow model.summary()
的输出,显示模型每层的名称、输出形状、参数数量,以及模型的总参数数目。
以下是上述代码的输出示例:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 64] 50,240
ReLU-2 [-1, 64] 0
Linear-3 [-1, 64] 4,160
ReLU-4 [-1, 64] 0
Linear-5 [-1, 10] 650
Softmax-6 [-1, 10] 0
================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
总结
在 PyTorch 中,torchsummary 库提供了类似于 TensorFlow model.summary()
的功能,能够直观地显示模型结构和参数。这对于快速了解模型的组成和调试非常有用。
重点内容:
- 安装
torchsummary
库:使用pip install torchsummary
安装。 - 使用
summary()
函数:显示模型的结构和参数。
通过使用 torchsummary
,PyTorch 用户可以获得与 TensorFlow 用户类似的便捷体验,从而更高效地进行深度学习模型的开发和调试。