文章目录
这里主要介绍pytorch 模型的网络结构的可视化
以 SRCNN 为例子来说明可视化的方法,以及参数量的计算
模型所占内存 = (参数量内存,特征图内存),
模型计算量 = (浮点数计算量)
1. torchsummary
class SRCNN(nn.Module):
def __init__(self, num_channels=1):
super(SRCNN, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)
return x
from torchinfo import summary
if __name__ == "__main__":
modelviz = SRCNN()
# 打印模型结构
print(modelviz)
summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
for p in modelviz.parameters():
if p.requires_grad:
print(p.shape)
可以得到的结果如下
具体什么含义呢?
接下来详细解释:
这里输入以 input_size=(8, 1, 8, 8) 为例子,
1) kernel shape 和 output shape 就是滤波器的参数shape 和 中间层的一些输出的 shape
2) Para # 表示的是有多少个参数,计算conv-2d 1-1的参数量,kernelshape = [9,9]:
W + b = 5248
9*9*64 + 64 = 5248
3) Multi-Adds : 统计的是浮点数运算, 计算conv-2d 1-1的计算量(浮点数运算次数):
filter(h, w, bias, channel), input(h, w, channel)
(9*9 + 1) * 64 * (8 * 8 * 8) = 2686976
4) Total params, Total mult-adds (M) 就是对 上面参数的求和
比如 5248+51232+801 = 57281
5)关于size:统计的是 参数 加上 中间层的 占用内存
输入内存Input size (MB): 0.00
是 8*1*8*8 * 4 / 1000000, 8*1*8*8 个float,每个4Byte, 除以一百万 ,约等于 0
中间特征内存Forward/backward pass size (MB): 0.40
8*8*8 * (1+64+64+32+32+1) = 99328
99328 * 4 / 1000000 = 0.397312
参数weight内存Params size (MB): 0.23
57281*4 / 1000000 = 0.229124
总内存Estimated Total Size (MB): 0.63
0.4 + 0.23
2. graphviz, torchviz
from torchviz import make_dot
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
modelviz = SRCNN().to(device)
input = torch.rand(8, 1, 8, 8).to(device)
out = modelviz(input)
print(out.shape)
# 1. 使用 torchviz 可视化
g = make_dot(out)
g.view() # 直接在当前路径下保存 pdf 并打开
# g.render(filename='netStructure/myNetModel', view=False, format='pdf') # 保存 pdf 到指定路径不打开
可视化结果是一个pdf,如下:写了比较多的步骤,所以网络结构感觉不是很清晰
3. 保存成pt文件后使用netron可视化
netron github:
安装:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple netron
代码:
torch.save(modelviz, "modelviz.pt")
import netron
modelData = 'modelviz.pt'
netron.start(modelData)
点击链接在浏览器中打开
4. tensorwatch
import tensorwatch as tw
# 3. 使用tensorwatch可视化
print(tw.model_stats(modelviz, (8, 1, 8, 8)))
tw.draw_model(modelviz, input)
打印的结果如图,可以和 summary 进行对比
5. get_model_complexity_info计算 FLOPs和parameters
# 4. get_model_complexity_info
from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(modelviz, ( 1, 8, 8), verbose=True, print_per_layer_stat=True)
print(macs, params)
params = float(params[:-3])
macs = float(macs[:-4])
print(macs * 8, params) # 8个图像的 FLOPs, 这里的结果 和 其他方法应该一致
结果:
6. 附上直接可以执行的code
from torch import nn
import torch
from torchviz import make_dot
import tensorwatch as tw
from torchinfo import summary
import netron
class SRCNN(nn.Module):
def __init__(self, num_channels=1):
super(SRCNN, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)
return x
if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
modelviz = SRCNN().to(device)
# 打印模型结构
print(modelviz)
summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
for p in modelviz.parameters():
if p.requires_grad:
print(p.shape)
# 创建输入, 看看输出结果
input = torch.rand(8, 1, 8, 8).to(device)
out = modelviz(input)
print('out:', out.shape)
# 1. 使用 torchviz 可视化
g = make_dot(out)
g.view() # 直接在当前路径下保存 pdf 并打开
# g.render(filename='netStructure/myNetModel', view=False, format='pdf') # 保存 pdf 到指定路径不打开
# 2. 保存成pt文件后进行可视化
torch.save(modelviz, "modelviz.pt")
modelData = 'modelviz.pt'
netron.start(modelData)
# 3. 使用tensorwatch可视化
# print(tw.model_stats(modelviz, (8, 1, 8, 8)))
# tw.draw_model(modelviz, input)
# 4. get_model_complexity_info
from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(modelviz, (1, 8, 8), verbose=True, print_per_layer_stat=True)
print(macs, params)
params = float(params[:-3])
macs = float(macs[:-4])
print(macs * 8, params) # 8个图像的 FLOPs, 这里的结果 和 其他方法应该一致
7. 参考
超实用的7种 pytorch 网络可视化方法,进来收藏一波
使用pytorchviz和Netron可视化pytorch网络结构
https://cloud.tencent.com/developer/article/1842049