Installing and Using torchsummary
1. Installation
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchsummary
The -i flag simply points pip at the Tsinghua PyPI mirror, which is faster from within China; a plain "pip install torchsummary" works just as well.

2. Usage
from torchsummary import summary
from torchvision import models

# Build MobileNetV2 and print a per-layer summary for a 3x112x112 input
# (input_size excludes the batch dimension)
net = models.mobilenet_v2()
summary(net.cuda(), input_size=(3, 112, 112))
# without a GPU: summary(net, input_size=(3, 112, 112), device="cpu")
----------------------------------------------------------------
Layer (type)               Output Shape         Param #
================================================================
Conv2d-1 [-1, 32, 56, 56] 864
BatchNorm2d-2 [-1, 32, 56, 56] 64
ReLU6-3 [-1, 32, 56, 56] 0
Conv2d-4 [-1, 32, 56, 56] 288
BatchNorm2d-5 [-1, 32, 56, 56] 64
ReLU6-6 [-1, 32, 56, 56] 0
Conv2d-7 [-1, 16, 56, 56] 512
BatchNorm2d-8 [-1, 16, 56, 56] 32
InvertedResidual-9 [-1, 16, 56, 56] 0
Conv2d-10 [-1, 96, 56, 56] 1,536
BatchNorm2d-11 [-1, 96, 56, 56] 192
ReLU6-12 [-1, 96, 56, 56] 0
Conv2d-13 [-1, 96, 28, 28] 864
BatchNorm2d-14 [-1, 96, 28, 28] 192
ReLU6-15 [-1, 96, 28, 28] 0
Conv2d-16 [-1, 24, 28, 28] 2,304
BatchNorm2d-17 [-1, 24, 28, 28] 48
InvertedResidual-18 [-1, 24, 28, 28] 0
Conv2d-19 [-1, 144, 28, 28] 3,456
BatchNorm2d-20 [-1, 144, 28, 28] 288
ReLU6-21 [-1, 144, 28, 28] 0
Conv2d-22 [-1, 144, 28, 28] 1,296
BatchNorm2d-23 [-1, 144, 28, 28] 288
ReLU6-24 [-1, 144, 28, 28] 0
Conv2d-25 [-1, 24, 28, 28] 3,456
BatchNorm2d-26 [-1, 24, 28, 28] 48
InvertedResidual-27 [-1, 24, 28, 28] 0
Conv2d-28 [-1, 144, 28, 28] 3,456
BatchNorm2d-29 [-1, 144, 28, 28] 288
ReLU6-30 [-1, 144, 28, 28] 0
Conv2d-31 [-1, 144, 14, 14] 1,296
BatchNorm2d-32 [-1, 144, 14, 14] 288
ReLU6-33 [-1, 144, 14, 14] 0
Conv2d-34 [-1, 32, 14, 14] 4,608
BatchNorm2d-35 [-1, 32, 14, 14] 64
InvertedResidual-36 [-1, 32, 14, 14] 0
Conv2d-37 [-1, 192, 14, 14] 6,144
BatchNorm2d-38 [-1, 192, 14, 14] 384
ReLU6-39 [-1, 192, 14, 14] 0
Conv2d-40 [-1, 192, 14, 14] 1,728
BatchNorm2d-41 [-1, 192, 14, 14] 384
ReLU6-42 [-1, 192, 14, 14] 0
Conv2d-43 [-1, 32, 14, 14] 6,144
BatchNorm2d-44 [-1, 32, 14, 14] 64
InvertedResidual-45 [-1, 32, 14, 14] 0
Conv2d-46 [-1, 192, 14, 14] 6,144
BatchNorm2d-47 [-1, 192, 14, 14] 384
ReLU6-48 [-1, 192, 14, 14] 0
Conv2d-49 [-1, 192, 14, 14] 1,728
BatchNorm2d-50 [-1, 192, 14, 14] 384
ReLU6-51 [-1, 192, 14, 14] 0
Conv2d-52 [-1, 32, 14, 14] 6,144
BatchNorm2d-53 [-1, 32, 14, 14] 64
InvertedResidual-54 [-1, 32, 14, 14] 0
Conv2d-55 [-1, 192, 14, 14] 6,144
BatchNorm2d-56 [-1, 192, 14, 14] 384
ReLU6-57 [-1, 192, 14, 14] 0
Conv2d-58 [-1, 192, 7, 7] 1,728
BatchNorm2d-59 [-1, 192, 7, 7] 384
ReLU6-60 [-1, 192, 7, 7] 0
Conv2d-61 [-1, 64, 7, 7] 12,288
BatchNorm2d-62 [-1, 64, 7, 7] 128
InvertedResidual-63 [-1, 64, 7, 7] 0
Conv2d-64 [-1, 384, 7, 7] 24,576
BatchNorm2d-65 [-1, 384, 7, 7] 768
ReLU6-66 [-1, 384, 7, 7] 0
Conv2d-67 [-1, 384, 7, 7] 3,456
BatchNorm2d-68 [-1, 384, 7, 7] 768
ReLU6-69 [-1, 384, 7, 7] 0
Conv2d-70 [-1, 64, 7, 7] 24,576
BatchNorm2d-71 [-1, 64, 7, 7] 128
InvertedResidual-72 [-1, 64, 7, 7] 0
Conv2d-73 [-1, 384, 7, 7] 24,576
BatchNorm2d-74 [-1, 384, 7, 7] 768
ReLU6-75 [-1, 384, 7, 7] 0
Conv2d-76 [-1, 384, 7, 7] 3,456
BatchNorm2d-77 [-1, 384, 7, 7] 768
ReLU6-78 [-1, 384, 7, 7] 0
Conv2d-79 [-1, 64, 7, 7] 24,576
BatchNorm2d-80 [-1, 64, 7, 7] 128
InvertedResidual-81 [-1, 64, 7, 7] 0
Conv2d-82 [-1, 384, 7, 7] 24,576
BatchNorm2d-83 [-1, 384, 7, 7] 768
ReLU6-84 [-1, 384, 7, 7] 0
Conv2d-85 [-1, 384, 7, 7] 3,456
BatchNorm2d-86 [-1, 384, 7, 7] 768
ReLU6-87 [-1, 384, 7, 7] 0
Conv2d-88 [-1, 64, 7, 7] 24,576
BatchNorm2d-89 [-1, 64, 7, 7] 128
InvertedResidual-90 [-1, 64, 7, 7] 0
Conv2d-91 [-1, 384, 7, 7] 24,576
BatchNorm2d-92 [-1, 384, 7, 7] 768
ReLU6-93 [-1, 384, 7, 7] 0
Conv2d-94 [-1, 384, 7, 7] 3,456
BatchNorm2d-95 [-1, 384, 7, 7] 768
ReLU6-96 [-1, 384, 7, 7] 0
Conv2d-97 [-1, 96, 7, 7] 36,864
BatchNorm2d-98 [-1, 96, 7, 7] 192
InvertedResidual-99 [-1, 96, 7, 7] 0
Conv2d-100 [-1, 576, 7, 7] 55,296
BatchNorm2d-101 [-1, 576, 7, 7] 1,152
ReLU6-102 [-1, 576, 7, 7] 0
Conv2d-103 [-1, 576, 7, 7] 5,184
BatchNorm2d-104 [-1, 576, 7, 7] 1,152
ReLU6-105 [-1, 576, 7, 7] 0
Conv2d-106 [-1, 96, 7, 7] 55,296
BatchNorm2d-107 [-1, 96, 7, 7] 192
InvertedResidual-108 [-1, 96, 7, 7] 0
Conv2d-109 [-1, 576, 7, 7] 55,296
BatchNorm2d-110 [-1, 576, 7, 7] 1,152
ReLU6-111 [-1, 576, 7, 7] 0
Conv2d-112 [-1, 576, 7, 7] 5,184
BatchNorm2d-113 [-1, 576, 7, 7] 1,152
ReLU6-114 [-1, 576, 7, 7] 0
Conv2d-115 [-1, 96, 7, 7] 55,296
BatchNorm2d-116 [-1, 96, 7, 7] 192
InvertedResidual-117 [-1, 96, 7, 7] 0
Conv2d-118 [-1, 576, 7, 7] 55,296
BatchNorm2d-119 [-1, 576, 7, 7] 1,152
ReLU6-120 [-1, 576, 7, 7] 0
Conv2d-121 [-1, 576, 4, 4] 5,184
BatchNorm2d-122 [-1, 576, 4, 4] 1,152
ReLU6-123 [-1, 576, 4, 4] 0
Conv2d-124 [-1, 160, 4, 4] 92,160
BatchNorm2d-125 [-1, 160, 4, 4] 320
InvertedResidual-126 [-1, 160, 4, 4] 0
Conv2d-127 [-1, 960, 4, 4] 153,600
BatchNorm2d-128 [-1, 960, 4, 4] 1,920
ReLU6-129 [-1, 960, 4, 4] 0
Conv2d-130 [-1, 960, 4, 4] 8,640
BatchNorm2d-131 [-1, 960, 4, 4] 1,920
ReLU6-132 [-1, 960, 4, 4] 0
Conv2d-133 [-1, 160, 4, 4] 153,600
BatchNorm2d-134 [-1, 160, 4, 4] 320
InvertedResidual-135 [-1, 160, 4, 4] 0
Conv2d-136 [-1, 960, 4, 4] 153,600
BatchNorm2d-137 [-1, 960, 4, 4] 1,920
ReLU6-138 [-1, 960, 4, 4] 0
Conv2d-139 [-1, 960, 4, 4] 8,640
BatchNorm2d-140 [-1, 960, 4, 4] 1,920
ReLU6-141 [-1, 960, 4, 4] 0
Conv2d-142 [-1, 160, 4, 4] 153,600
BatchNorm2d-143 [-1, 160, 4, 4] 320
InvertedResidual-144 [-1, 160, 4, 4] 0
Conv2d-145 [-1, 960, 4, 4] 153,600
BatchNorm2d-146 [-1, 960, 4, 4] 1,920
ReLU6-147 [-1, 960, 4, 4] 0
Conv2d-148 [-1, 960, 4, 4] 8,640
BatchNorm2d-149 [-1, 960, 4, 4] 1,920
ReLU6-150 [-1, 960, 4, 4] 0
Conv2d-151 [-1, 320, 4, 4] 307,200
BatchNorm2d-152 [-1, 320, 4, 4] 640
InvertedResidual-153 [-1, 320, 4, 4] 0
Conv2d-154 [-1, 1280, 4, 4] 409,600
BatchNorm2d-155 [-1, 1280, 4, 4] 2,560
ReLU6-156 [-1, 1280, 4, 4] 0
Dropout-157 [-1, 1280] 0
Linear-158 [-1, 1000] 1,281,000
================================================================
Total params: 3,504,872
Trainable params: 3,504,872
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.14
Forward/backward pass size (MB): 38.95
Params size (MB): 13.37
Estimated Total Size (MB): 52.47
----------------------------------------------------------------
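The size estimates at the bottom follow from simple arithmetic, assuming 4 bytes per float32 value (this is exactly what the source in section 3 computes). A quick sanity check in plain Python, using the numbers reported above:

params = 3_504_872                    # "Total params" reported above
print(params * 4 / 1024 ** 2)         # ≈ 13.37 -> Params size (MB)
print(3 * 112 * 112 * 4 / 1024 ** 2)  # ≈ 0.14  -> Input size (MB)

The forward/backward pass size is twice the total number of activation values times 4 bytes; the factor of 2 accounts for the gradients stored during backpropagation.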
3. torchsummary source code
- I have seen other bloggers point out that torchsummary is not entirely rigorous in practice and leaves some things out of its accounting, and that they patched it themselves. If you are interested, read through the source below and try your own modifications; feel free to share them with me so I can learn from them too.
import torch
import torch.nn as nn
from torch.autograd import Variable  # unused in this function; kept from the original source
from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device="cuda"):

    def register_hook(module):

        def hook(module, input, output):
            # Name each entry "ClassName-index", e.g. "Conv2d-1"
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            # Parameter count = weight elements + bias elements
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        # Hook only leaf layers, not containers or the top-level model itself
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # Wrap a single input size in a list so multi-input networks work too
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # Batch size of 2 so BatchNorm layers do not fail on the dummy pass
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    summary = OrderedDict()
    hooks = []

    # Register the hooks, run one forward pass, then detach the hooks
    model.apply(register_hook)
    model(*x)
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # Sizes assume 4 bytes per float32; activations count twice (gradients)
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
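Everything above hinges on PyTorch's register_forward_hook API: summary() attaches a hook to every leaf layer, runs one forward pass on dummy input, and the hooks record output shapes and parameter counts as they fire. Below is a minimal standalone sketch of that mechanism (record_shape and shapes are illustrative names for this sketch, not torchsummary API):

import torch
import torch.nn as nn

# Collect (layer name, output shape) pairs as forward passes happen
shapes = []

def record_shape(module, inputs, output):
    shapes.append((module.__class__.__name__, list(output.size())))

layer = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
handle = layer.register_forward_hook(record_shape)
layer(torch.rand(1, 3, 112, 112))  # the hook fires during this forward pass
handle.remove()                    # detach the hook, as summary() does
print(shapes)                      # [('Conv2d', [1, 32, 56, 56])]

Removing the handle afterwards matters: a leftover hook keeps firing on every future forward pass, which is why summary() removes all of its hooks once the dummy pass completes.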