文章目录
这里主要介绍pytorch 模型的网络结构的可视化
以 SRCNN 为例子来说明可视化的方法,以及参数量的计算
模型所占内存 = (参数量内存,特征图内存),
模型计算量 = (浮点数计算量)
1. torchsummary
class SRCNN(nn.Module):
def __init__(self, num_channels=1):
super(SRCNN, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)
return x
from torchinfo import summary
if __name__ == "__main__":
modelviz = SRCNN()
# 打印模型结构
print(modelviz)
summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
for p in modelviz.parameters():
if p.requires_grad

本文介绍了如何使用torchsummary、torchviz、netron、tensorwatch和get_model_complexity_info等工具对PyTorch的SRCNN模型进行网络结构可视化和参数计算。通过实例展示了各方法的使用步骤,包括模型的参数量、FLOPs计算以及内存占用分析,帮助理解模型复杂性和资源需求。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



