A Fix for TorchSummary Not Supporting Dict Data Types
torchsummary is a handy package for inspecting a network's structure, but at present it only supports models whose inputs and outputs are torch.Tensor. In more complex models, the data fed through the network is not necessarily a tensor; it may be a list or a dict.
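For instance, a model whose forward pass consumes and produces dicts, like the made-up DictModel below, cannot be summarized by the stock package (a minimal sketch for illustration; it is not a model from this post):

import torch
import torch.nn as nn

class DictModel(nn.Module):
    """A toy model whose forward pass takes and returns a dict."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, inputs):
        # inputs is expected to be {"img": Tensor}, not a bare tensor
        feat = self.conv(inputs["img"])
        return {"feat": feat}

# torchsummary.summary(DictModel(), (3, 480, 640)) fails on a model like
# this: it feeds a plain tensor into forward() and then calls .size() on
# the dict output.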
Add the following at a suitable point in train.py:
from mysummary import summary

summary(model, torch.rand(1, 3, 480, 640).cuda())
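Because the modified summary takes the actual input object rather than a shape tuple, a dict input can be passed the same way. A sketch (the key name "img" is hypothetical and must match whatever your model's forward expects):

from mysummary import summary

inputs = {"img": torch.rand(1, 3, 480, 640).cuda()}  # key name is an example
summary(model, inputs)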
Create a new file mysummary.py with the following contents:
import torch
import torch.nn as nn
from collections import OrderedDict
import numpy as np
def summary(model, data, batch_size=-1, device="cuda"):
    """
    Adapted from torchsummary.summary; modified to handle dict inputs/outputs.
    """
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()

            if isinstance(input, (list, tuple)):
                # forward hooks receive the positional args as a tuple
                summary[m_key]["input_shape"] = list()
                # record the input shape(s)
                if isinstance(input[0], torch.Tensor):
                    input = input[0]
                else:
                    for l_i in input[0]:
                        summary[m_key]["input_shape"].append(l_i.size())
            if isinstance(input, torch.Tensor):
                summary[m_key]["input_shape"] = list(input.size())
                summary[m_key]["input_shape"][0] = batch_size
            # dict inputs were not an issue in my case;
            # if they are in yours, handle them along these lines:
            # if isinstance(input, dict):
            #     summary[m_key]["input_shape"] = input[0].size()

            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [batch_size] + list(o.size())[1:] for o in output
                ]
            elif isinstance(output, dict):
                # for dict outputs, record the keys instead of a tensor shape
                summary[m_key]["output_shape"] = [k for k in output.keys()]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))
    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"
    # dtype is kept from the original torchsummary; it is unused here
    # because the caller supplies the input data directly
    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # you need to build the input data yourself before calling this function
    x = data
    input_size = []
    # get the input shape
    if isinstance(x, torch.Tensor):
        input_size = data.size()
    elif isinstance(x, (list, tuple)):
        input_size = data[0].size()
    elif isinstance(x, dict):
        input_size = list(data.values())[0].size()
    if batch_size == -1:
        batch_size = input_size[0]
    input_size = input_size[1:]

    # create properties
    summary = OrderedDict()
    hooks = []

    # make one forward pass before registering hooks: some of my network
    # blocks need to see the input shape once to create their linear layers,
    # and some models need this first pass for initialization
    model(x)

    # register hooks, then run a second forward pass to record the shapes
    model.apply(register_hook)
    model(x)

    # remove these hooks
    for h in hooks:
        h.remove()
print("--------------------------------------------------------------------------")
line_new = "{:>25} {:>30} {:>15}".format("Layer (type)", "Output Shape", "Param #")
print(line_new)
print("==========================================================================")
total_params = 0
total_output = 0
trainable_params = 0
for layer in summary:
# input_shape, output_shape, trainable, nb_params
total_params += summary[layer]["nb_params"]
# total_output += np.prod(summary[layer]["output_shape"])
output_shape = summary[layer]["output_shape"]
if isinstance(summary[layer]["output_shape"][0], list):
output_shape = ""
for out_shape_list in summary[layer]["output_shape"]:
output_shape = f"{output_shape} {out_shape_list}"
if isinstance(summary[layer]['output_shape'][-1], int):
total_output = summary[layer]['output_shape']
if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True:
trainable_params += summary[layer]["nb_params"]
line_new = "{:>25} {:>30} {:>15}".format(
layer,
str(output_shape),
"{0:,}".format(summary[layer]["nb_params"]),
)
print(line_new)
# assume 4 bytes/number (float on cuda).
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * np.prod(total_output) * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size
print("==========================================================================")
print("Total params: {0:,}".format(total_params))
print("Trainable params: {0:,}".format(trainable_params))
print("Non-trainable params: {0:,}".format(total_params - trainable_params))
print("--------------------------------------------------------------------------")
print("Input size (MB): %0.2f" % total_input_size)
print("Forward/backward pass size (MB): %0.2f" % total_output_size)
print("Params size (MB): %0.2f" % total_params_size)
print("Estimated Total Size (MB): %0.2f" % total_size)
print("--------------------------------------------------------------------------")
# return summary
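To sanity-check the dict handling without a full detection model, something like the hypothetical ToyNet below can be run through mysummary; the DictHead row should print its dict keys in the Output Shape column:

import torch
import torch.nn as nn
from mysummary import summary

class DictHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls = nn.Conv2d(16, 8, 1)

    def forward(self, x):
        # dict output: summary() records the keys instead of a shape
        return {"cls": self.cls(x)}

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)
        self.head = DictHead()

    def forward(self, x):
        return self.head(self.stem(x))

summary(ToyNet(), torch.rand(1, 3, 64, 64), device="cpu")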
The output:
--------------------------------------------------------------------------
Layer (type) Output Shape Param #
==========================================================================
Conv2d-1 [1, 32, 240, 320] 3,456
BatchNorm2d-2 [1, 32, 240, 320] 64
SiLU-3 [1, 32, 240, 320] 0
Conv-4 [1, 32, 240, 320] 0
Conv2d-5 [1, 64, 120, 160] 18,432
BatchNorm2d-6 [1, 64, 120, 160] 128
SiLU-7 [1, 64, 120, 160] 0
Conv-8 [1, 64, 120, 160] 0
Conv2d-9 [1, 32, 120, 160] 2,048
BatchNorm2d-10 [1, 32, 120, 160] 64
SiLU-11 [1, 32, 120, 160] 0
Conv-12 [1, 32, 120, 160] 0
Conv2d-13 [1, 32, 120, 160] 1,024
BatchNorm2d-14 [1, 32, 120, 160] 64
SiLU-15 [1, 32, 120, 160] 0
Conv-16 [1, 32, 120, 160] 0
Conv2d-17 [1, 32, 120, 160] 9,216
BatchNorm2d-18 [1, 32, 120, 160] 64
SiLU-19 [1, 32, 120, 160] 0
Conv-20 [1, 32, 120, 160] 0
Bottleneck-21 [1, 32, 120, 160] 0
Conv2d-22 [1, 32, 120, 160] 2,048
BatchNorm2d-23 [1, 32, 120, 160] 64
SiLU-24 [1, 32, 120, 160] 0
Conv-25 [1, 32, 120, 160] 0
Conv2d-26 [1, 64, 120, 160] 4,096
BatchNorm2d-27 [1, 64, 120, 160] 128
SiLU-28 [1, 64, 120, 160] 0
Conv-29 [1, 64, 120, 160] 0
C3-30 [1, 64, 120, 160] 0
Conv2d-31 [1, 128, 60, 80] 73,728
BatchNorm2d-32 [1, 128, 60, 80] 256
SiLU-33 [1, 128, 60, 80] 0
Conv-34 [1, 128, 60, 80] 0
Conv2d-35 [1, 64, 60, 80] 8,192
BatchNorm2d-36 [1, 64, 60, 80] 128
SiLU-37 [1, 64, 60, 80] 0
Conv-38 [1, 64, 60, 80] 0
Conv2d-39 [1, 64, 60, 80] 4,096
BatchNorm2d-40 [1, 64, 60, 80] 128
SiLU-41 [1, 64, 60, 80] 0
Conv-42 [1, 64, 60, 80] 0
Conv2d-43 [1, 64, 60, 80] 36,864
BatchNorm2d-44 [1, 64, 60, 80] 128
SiLU-45 [1, 64, 60, 80] 0
Conv-46 [1, 64, 60, 80] 0
Bottleneck-47 [1, 64, 60, 80] 0
Conv2d-48 [1, 64, 60, 80] 4,096
BatchNorm2d-49 [1, 64, 60, 80] 128
SiLU-50 [1, 64, 60, 80] 0
Conv-51 [1, 64, 60, 80] 0
Conv2d-52 [1, 64, 60, 80] 36,864
BatchNorm2d-53 [1, 64, 60, 80] 128
SiLU-54 [1, 64, 60, 80] 0
Conv-55 [1, 64, 60, 80] 0
Bottleneck-56 [1, 64, 60, 80] 0
Conv2d-57 [1, 64, 60, 80] 8,192
BatchNorm2d-58 [1, 64, 60, 80] 128
SiLU-59 [1, 64, 60, 80] 0
Conv-60 [1, 64, 60, 80] 0
Conv2d-61 [1, 128, 60, 80] 16,384
BatchNorm2d-62 [1, 128, 60, 80] 256
SiLU-63 [1, 128, 60, 80] 0
Conv-64 [1, 128, 60, 80] 0
C3-65 [1, 128, 60, 80] 0
Conv2d-66 [1, 256, 30, 40] 294,912
BatchNorm2d-67 [1, 256, 30, 40] 512
SiLU-68 [1, 256, 30, 40] 0
Conv-69 [1, 256, 30, 40] 0
Conv2d-70 [1, 128, 30, 40] 32,768
BatchNorm2d-71 [1, 128, 30, 40] 256
SiLU-72 [1, 128, 30, 40] 0
Conv-73 [1, 128, 30, 40] 0
Conv2d-74 [1, 128, 30, 40] 16,384
BatchNorm2d-75 [1, 128, 30, 40] 256
SiLU-76 [1, 128, 30, 40] 0
Conv-77 [1, 128, 30, 40] 0
Conv2d-78 [1, 128, 30, 40] 147,456
BatchNorm2d-79 [1, 128, 30, 40] 256
SiLU-80 [1, 128, 30, 40] 0
Conv-81 [1, 128, 30, 40] 0
Bottleneck-82 [1, 128, 30, 40] 0
Conv2d-83 [1, 128, 30, 40] 16,384
BatchNorm2d-84 [1, 128, 30, 40] 256
SiLU-85 [1, 128, 30, 40] 0
Conv-86 [1, 128, 30, 40] 0
Conv2d-87 [1, 128, 30, 40] 147,456
BatchNorm2d-88 [1, 128, 30, 40] 256
SiLU-89 [1, 128, 30, 40] 0
Conv-90 [1, 128, 30, 40] 0
Bottleneck-91 [1, 128, 30, 40] 0
Conv2d-92 [1, 128, 30, 40] 16,384
BatchNorm2d-93 [1, 128, 30, 40] 256
SiLU-94 [1, 128, 30, 40] 0
Conv-95 [1, 128, 30, 40] 0
Conv2d-96 [1, 128, 30, 40] 147,456
BatchNorm2d-97 [1, 128, 30, 40] 256
SiLU-98 [1, 128, 30, 40] 0
Conv-99 [1, 128, 30, 40] 0
Bottleneck-100 [1, 128, 30, 40] 0
Conv2d-101 [1, 128, 30, 40] 32,768
BatchNorm2d-102 [1, 128, 30, 40] 256
SiLU-103 [1, 128, 30, 40] 0
Conv-104 [1, 128, 30, 40] 0
Conv2d-105 [1, 256, 30, 40] 65,536
BatchNorm2d-106 [1, 256, 30, 40] 512
SiLU-107 [1, 256, 30, 40] 0
Conv-108 [1, 256, 30, 40] 0
C3-109 [1, 256, 30, 40] 0
Conv2d-110 [1, 512, 15, 20] 1,179,648
BatchNorm2d-111 [1, 512, 15, 20] 1,024
SiLU-112 [1, 512, 15, 20] 0
Conv-113 [1, 512, 15, 20] 0
Conv2d-114 [1, 256, 15, 20] 131,072
BatchNorm2d-115 [1, 256, 15, 20] 512
SiLU-116 [1, 256, 15, 20] 0
Conv-117 [1, 256, 15, 20] 0
Conv2d-118 [1, 256, 15, 20] 65,536
BatchNorm2d-119 [1, 256, 15, 20] 512
SiLU-120 [1, 256, 15, 20] 0
Conv-121 [1, 256, 15, 20] 0
Conv2d-122 [1, 256, 15, 20] 589,824
BatchNorm2d-123 [1, 256, 15, 20] 512
SiLU-124 [1, 256, 15, 20] 0
Conv-125 [1, 256, 15, 20] 0
Bottleneck-126 [1, 256, 15, 20] 0
Conv2d-127 [1, 256, 15, 20] 131,072
BatchNorm2d-128 [1, 256, 15, 20] 512
SiLU-129 [1, 256, 15, 20] 0
Conv-130 [1, 256, 15, 20] 0
Conv2d-131 [1, 512, 15, 20] 262,144
BatchNorm2d-132 [1, 512, 15, 20] 1,024
SiLU-133 [1, 512, 15, 20] 0
Conv-134 [1, 512, 15, 20] 0
C3-135 [1, 512, 15, 20] 0
Conv2d-136 [1, 256, 15, 20] 131,072
BatchNorm2d-137 [1, 256, 15, 20] 512
SiLU-138 [1, 256, 15, 20] 0
Conv-139 [1, 256, 15, 20] 0
MaxPool2d-140 [1, 256, 15, 20] 0
MaxPool2d-141 [1, 256, 15, 20] 0
MaxPool2d-142 [1, 256, 15, 20] 0
Conv2d-143 [1, 512, 15, 20] 524,288
BatchNorm2d-144 [1, 512, 15, 20] 1,024
SiLU-145 [1, 512, 15, 20] 0
Conv-146 [1, 512, 15, 20] 0
SPPF-147 [1, 512, 15, 20] 0
Conv2d-148 [1, 256, 15, 20] 131,072
BatchNorm2d-149 [1, 256, 15, 20] 512
SiLU-150 [1, 256, 15, 20] 0
Conv-151 [1, 256, 15, 20] 0
Upsample-152 [1, 256, 30, 40] 0
Concat-153 [1, 512, 30, 40] 0
Conv2d-154 [1, 128, 30, 40] 65,536
BatchNorm2d-155 [1, 128, 30, 40] 256
SiLU-156 [1, 128, 30, 40] 0
Conv-157 [1, 128, 30, 40] 0
Conv2d-158 [1, 128, 30, 40] 16,384
BatchNorm2d-159 [1, 128, 30, 40] 256
SiLU-160 [1, 128, 30, 40] 0
Conv-161 [1, 128, 30, 40] 0
Conv2d-162 [1, 128, 30, 40] 147,456
BatchNorm2d-163 [1, 128, 30, 40] 256
SiLU-164 [1, 128, 30, 40] 0
Conv-165 [1, 128, 30, 40] 0
Bottleneck-166 [1, 128, 30, 40] 0
Conv2d-167 [1, 128, 30, 40] 65,536
BatchNorm2d-168 [1, 128, 30, 40] 256
SiLU-169 [1, 128, 30, 40] 0
Conv-170 [1, 128, 30, 40] 0
Conv2d-171 [1, 256, 30, 40] 65,536
BatchNorm2d-172 [1, 256, 30, 40] 512
SiLU-173 [1, 256, 30, 40] 0
Conv-174 [1, 256, 30, 40] 0
C3-175 [1, 256, 30, 40] 0
Conv2d-176 [1, 128, 30, 40] 32,768
BatchNorm2d-177 [1, 128, 30, 40] 256
SiLU-178 [1, 128, 30, 40] 0
Conv-179 [1, 128, 30, 40] 0
Upsample-180 [1, 128, 60, 80] 0
Concat-181 [1, 256, 60, 80] 0
Conv2d-182 [1, 64, 60, 80] 16,384
BatchNorm2d-183 [1, 64, 60, 80] 128
SiLU-184 [1, 64, 60, 80] 0
Conv-185 [1, 64, 60, 80] 0
Conv2d-186 [1, 64, 60, 80] 4,096
BatchNorm2d-187 [1, 64, 60, 80] 128
SiLU-188 [1, 64, 60, 80] 0
Conv-189 [1, 64, 60, 80] 0
Conv2d-190 [1, 64, 60, 80] 36,864
BatchNorm2d-191 [1, 64, 60, 80] 128
SiLU-192 [1, 64, 60, 80] 0
Conv-193 [1, 64, 60, 80] 0
Bottleneck-194 [1, 64, 60, 80] 0
Conv2d-195 [1, 64, 60, 80] 16,384
BatchNorm2d-196 [1, 64, 60, 80] 128
SiLU-197 [1, 64, 60, 80] 0
Conv-198 [1, 64, 60, 80] 0
Conv2d-199 [1, 128, 60, 80] 16,384
BatchNorm2d-200 [1, 128, 60, 80] 256
SiLU-201 [1, 128, 60, 80] 0
Conv-202 [1, 128, 60, 80] 0
C3-203 [1, 128, 60, 80] 0
Conv2d-204 [1, 128, 30, 40] 147,456
BatchNorm2d-205 [1, 128, 30, 40] 256
SiLU-206 [1, 128, 30, 40] 0
Conv-207 [1, 128, 30, 40] 0
Concat-208 [1, 256, 30, 40] 0
Conv2d-209 [1, 128, 30, 40] 32,768
BatchNorm2d-210 [1, 128, 30, 40] 256
SiLU-211 [1, 128, 30, 40] 0
Conv-212 [1, 128, 30, 40] 0
Conv2d-213 [1, 128, 30, 40] 16,384
BatchNorm2d-214 [1, 128, 30, 40] 256
SiLU-215 [1, 128, 30, 40] 0
Conv-216 [1, 128, 30, 40] 0
Conv2d-217 [1, 128, 30, 40] 147,456
BatchNorm2d-218 [1, 128, 30, 40] 256
SiLU-219 [1, 128, 30, 40] 0
Conv-220 [1, 128, 30, 40] 0
Bottleneck-221 [1, 128, 30, 40] 0
Conv2d-222 [1, 128, 30, 40] 32,768
BatchNorm2d-223 [1, 128, 30, 40] 256
SiLU-224 [1, 128, 30, 40] 0
Conv-225 [1, 128, 30, 40] 0
Conv2d-226 [1, 256, 30, 40] 65,536
BatchNorm2d-227 [1, 256, 30, 40] 512
SiLU-228 [1, 256, 30, 40] 0
Conv-229 [1, 256, 30, 40] 0
C3-230 [1, 256, 30, 40] 0
Conv2d-231 [1, 256, 15, 20] 589,824
BatchNorm2d-232 [1, 256, 15, 20] 512
SiLU-233 [1, 256, 15, 20] 0
Conv-234 [1, 256, 15, 20] 0
Concat-235 [1, 512, 15, 20] 0
Conv2d-236 [1, 256, 15, 20] 131,072
BatchNorm2d-237 [1, 256, 15, 20] 512
SiLU-238 [1, 256, 15, 20] 0
Conv-239 [1, 256, 15, 20] 0
Conv2d-240 [1, 256, 15, 20] 65,536
BatchNorm2d-241 [1, 256, 15, 20] 512
SiLU-242 [1, 256, 15, 20] 0
Conv-243 [1, 256, 15, 20] 0
Conv2d-244 [1, 256, 15, 20] 589,824
BatchNorm2d-245 [1, 256, 15, 20] 512
SiLU-246 [1, 256, 15, 20] 0
Conv-247 [1, 256, 15, 20] 0
Bottleneck-248 [1, 256, 15, 20] 0
Conv2d-249 [1, 256, 15, 20] 131,072
BatchNorm2d-250 [1, 256, 15, 20] 512
SiLU-251 [1, 256, 15, 20] 0
Conv-252 [1, 256, 15, 20] 0
Conv2d-253 [1, 512, 15, 20] 262,144
BatchNorm2d-254 [1, 512, 15, 20] 1,024
SiLU-255 [1, 512, 15, 20] 0
Conv-256 [1, 512, 15, 20] 0
C3-257 [1, 512, 15, 20] 0
Conv2d-258 [1, 255, 60, 80] 32,895
Conv2d-259 [1, 255, 30, 40] 65,535
Conv2d-260 [1, 255, 15, 20] 130,815
Detect-261 [1, 3, 60, 80, 85] [1, 3, 30, 40, 85] [1, 3, 15, 20, 85] 0
==========================================================================
Total params: 7,235,389
Trainable params: 7,235,389
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 3.52
Forward/backward pass size (MB): 0.58
Params size (MB): 27.60
Estimated Total Size (MB): 31.70
--------------------------------------------------------------------------
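Note the Detect-261 row: the detection head returns a list of three tensors, so the list branch in the hook records all three shapes, and they are printed together on one line instead of crashing the summary.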