要打印 PyTorch 深度学习模型的方法,可以使用多种方式来查看模型的结构、参数和属性。以下是几种常用的方法:
1. 使用 print(model)
直接打印模型对象可以查看模型的结构,包括每一层的类型和参数数量。
import torch
import torch.nn as nn
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(64 * 28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = x.view(x.size(0), -1)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleModel()
# 打印模型结构
print(model)
#输出示例
SimpleModel(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1): Linear(in_features=50176, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
2. 使用 summary
函数
torchsummary
库提供了 summary
函数,可以打印出模型的详细信息,包括每层的输出形状和参数数量。
from torchsummary import summary
model = SimpleModel().cuda() # 确保模型在 GPU 上
# 打印模型总结信息
summary(model, (1, 28, 28)) # 输入尺寸为 (1, 28, 28)
#输出示例
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 28, 28] 320
Conv2d-2 [-1, 64, 28, 28] 18,496
Linear-3 [-1, 128] 6,421,504
Linear-4 [-1, 10] 1,290
================================================================
Total params: 6,441,610
Trainable params: 6,441,610
Non-trainable params: 0
----------------------------------------------------------------
3. 使用 state_dict
state_dict
可以查看模型的所有参数(权重和偏置)。
# 打印模型的 state_dict
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#输出示例
conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([128, 50176])
fc1.bias torch.Size([128])
fc2.weight torch.Size([10, 128])
fc2.bias torch.Size([10])
5. 查看模型的各层详细信息
通过迭代模型的子模块,可以查看每一层的详细信息。
# 打印每一层的详细信息
for name, layer in model.named_children():
print(name, ":", layer)
#输出示例
conv1 : Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2 : Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
fc1 : Linear(in_features=50176, out_features=128, bias=True)
fc2 : Linear(in_features=128, out_features=10, bias=True)