[Download-Tools] 2. Installing and Using torchsummary


1. Installation

Install torchsummary from the Tsinghua PyPI mirror (often faster than the default index in mainland China):

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchsummary

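To confirm the install worked, importing the package in a Python shell is a quick check:

# Quick sanity check after installation
from torchsummary import summary
print(summary)  # prints the function object if the import succeeded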

2. Usage

from torchsummary import summary
from torchvision import models

# Build MobileNetV2 and print a per-layer summary for a 3x112x112 input
net = models.mobilenet_v2()
summary(net.cuda(), input_size=(3, 112, 112))  # .cuda() requires an available GPU

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 56, 56]             864
       BatchNorm2d-2           [-1, 32, 56, 56]              64
             ReLU6-3           [-1, 32, 56, 56]               0
            Conv2d-4           [-1, 32, 56, 56]             288
       BatchNorm2d-5           [-1, 32, 56, 56]              64
             ReLU6-6           [-1, 32, 56, 56]               0
            Conv2d-7           [-1, 16, 56, 56]             512
       BatchNorm2d-8           [-1, 16, 56, 56]              32
  InvertedResidual-9           [-1, 16, 56, 56]               0
           Conv2d-10           [-1, 96, 56, 56]           1,536
      BatchNorm2d-11           [-1, 96, 56, 56]             192
            ReLU6-12           [-1, 96, 56, 56]               0
           Conv2d-13           [-1, 96, 28, 28]             864
      BatchNorm2d-14           [-1, 96, 28, 28]             192
            ReLU6-15           [-1, 96, 28, 28]               0
           Conv2d-16           [-1, 24, 28, 28]           2,304
      BatchNorm2d-17           [-1, 24, 28, 28]              48
 InvertedResidual-18           [-1, 24, 28, 28]               0
           Conv2d-19          [-1, 144, 28, 28]           3,456
      BatchNorm2d-20          [-1, 144, 28, 28]             288
            ReLU6-21          [-1, 144, 28, 28]               0
           Conv2d-22          [-1, 144, 28, 28]           1,296
      BatchNorm2d-23          [-1, 144, 28, 28]             288
            ReLU6-24          [-1, 144, 28, 28]               0
           Conv2d-25           [-1, 24, 28, 28]           3,456
      BatchNorm2d-26           [-1, 24, 28, 28]              48
 InvertedResidual-27           [-1, 24, 28, 28]               0
           Conv2d-28          [-1, 144, 28, 28]           3,456
      BatchNorm2d-29          [-1, 144, 28, 28]             288
            ReLU6-30          [-1, 144, 28, 28]               0
           Conv2d-31          [-1, 144, 14, 14]           1,296
      BatchNorm2d-32          [-1, 144, 14, 14]             288
            ReLU6-33          [-1, 144, 14, 14]               0
           Conv2d-34           [-1, 32, 14, 14]           4,608
      BatchNorm2d-35           [-1, 32, 14, 14]              64
 InvertedResidual-36           [-1, 32, 14, 14]               0
           Conv2d-37          [-1, 192, 14, 14]           6,144
      BatchNorm2d-38          [-1, 192, 14, 14]             384
            ReLU6-39          [-1, 192, 14, 14]               0
           Conv2d-40          [-1, 192, 14, 14]           1,728
      BatchNorm2d-41          [-1, 192, 14, 14]             384
            ReLU6-42          [-1, 192, 14, 14]               0
           Conv2d-43           [-1, 32, 14, 14]           6,144
      BatchNorm2d-44           [-1, 32, 14, 14]              64
 InvertedResidual-45           [-1, 32, 14, 14]               0
           Conv2d-46          [-1, 192, 14, 14]           6,144
      BatchNorm2d-47          [-1, 192, 14, 14]             384
            ReLU6-48          [-1, 192, 14, 14]               0
           Conv2d-49          [-1, 192, 14, 14]           1,728
      BatchNorm2d-50          [-1, 192, 14, 14]             384
            ReLU6-51          [-1, 192, 14, 14]               0
           Conv2d-52           [-1, 32, 14, 14]           6,144
      BatchNorm2d-53           [-1, 32, 14, 14]              64
 InvertedResidual-54           [-1, 32, 14, 14]               0
           Conv2d-55          [-1, 192, 14, 14]           6,144
      BatchNorm2d-56          [-1, 192, 14, 14]             384
            ReLU6-57          [-1, 192, 14, 14]               0
           Conv2d-58            [-1, 192, 7, 7]           1,728
      BatchNorm2d-59            [-1, 192, 7, 7]             384
            ReLU6-60            [-1, 192, 7, 7]               0
           Conv2d-61             [-1, 64, 7, 7]          12,288
      BatchNorm2d-62             [-1, 64, 7, 7]             128
 InvertedResidual-63             [-1, 64, 7, 7]               0
           Conv2d-64            [-1, 384, 7, 7]          24,576
      BatchNorm2d-65            [-1, 384, 7, 7]             768
            ReLU6-66            [-1, 384, 7, 7]               0
           Conv2d-67            [-1, 384, 7, 7]           3,456
      BatchNorm2d-68            [-1, 384, 7, 7]             768
            ReLU6-69            [-1, 384, 7, 7]               0
           Conv2d-70             [-1, 64, 7, 7]          24,576
      BatchNorm2d-71             [-1, 64, 7, 7]             128
 InvertedResidual-72             [-1, 64, 7, 7]               0
           Conv2d-73            [-1, 384, 7, 7]          24,576
      BatchNorm2d-74            [-1, 384, 7, 7]             768
            ReLU6-75            [-1, 384, 7, 7]               0
           Conv2d-76            [-1, 384, 7, 7]           3,456
      BatchNorm2d-77            [-1, 384, 7, 7]             768
            ReLU6-78            [-1, 384, 7, 7]               0
           Conv2d-79             [-1, 64, 7, 7]          24,576
      BatchNorm2d-80             [-1, 64, 7, 7]             128
 InvertedResidual-81             [-1, 64, 7, 7]               0
           Conv2d-82            [-1, 384, 7, 7]          24,576
      BatchNorm2d-83            [-1, 384, 7, 7]             768
            ReLU6-84            [-1, 384, 7, 7]               0
           Conv2d-85            [-1, 384, 7, 7]           3,456
      BatchNorm2d-86            [-1, 384, 7, 7]             768
            ReLU6-87            [-1, 384, 7, 7]               0
           Conv2d-88             [-1, 64, 7, 7]          24,576
      BatchNorm2d-89             [-1, 64, 7, 7]             128
 InvertedResidual-90             [-1, 64, 7, 7]               0
           Conv2d-91            [-1, 384, 7, 7]          24,576
      BatchNorm2d-92            [-1, 384, 7, 7]             768
            ReLU6-93            [-1, 384, 7, 7]               0
           Conv2d-94            [-1, 384, 7, 7]           3,456
      BatchNorm2d-95            [-1, 384, 7, 7]             768
            ReLU6-96            [-1, 384, 7, 7]               0
           Conv2d-97             [-1, 96, 7, 7]          36,864
      BatchNorm2d-98             [-1, 96, 7, 7]             192
 InvertedResidual-99             [-1, 96, 7, 7]               0
          Conv2d-100            [-1, 576, 7, 7]          55,296
     BatchNorm2d-101            [-1, 576, 7, 7]           1,152
           ReLU6-102            [-1, 576, 7, 7]               0
          Conv2d-103            [-1, 576, 7, 7]           5,184
     BatchNorm2d-104            [-1, 576, 7, 7]           1,152
           ReLU6-105            [-1, 576, 7, 7]               0
          Conv2d-106             [-1, 96, 7, 7]          55,296
     BatchNorm2d-107             [-1, 96, 7, 7]             192
InvertedResidual-108             [-1, 96, 7, 7]               0
          Conv2d-109            [-1, 576, 7, 7]          55,296
     BatchNorm2d-110            [-1, 576, 7, 7]           1,152
           ReLU6-111            [-1, 576, 7, 7]               0
          Conv2d-112            [-1, 576, 7, 7]           5,184
     BatchNorm2d-113            [-1, 576, 7, 7]           1,152
           ReLU6-114            [-1, 576, 7, 7]               0
          Conv2d-115             [-1, 96, 7, 7]          55,296
     BatchNorm2d-116             [-1, 96, 7, 7]             192
InvertedResidual-117             [-1, 96, 7, 7]               0
          Conv2d-118            [-1, 576, 7, 7]          55,296
     BatchNorm2d-119            [-1, 576, 7, 7]           1,152
           ReLU6-120            [-1, 576, 7, 7]               0
          Conv2d-121            [-1, 576, 4, 4]           5,184
     BatchNorm2d-122            [-1, 576, 4, 4]           1,152
           ReLU6-123            [-1, 576, 4, 4]               0
          Conv2d-124            [-1, 160, 4, 4]          92,160
     BatchNorm2d-125            [-1, 160, 4, 4]             320
InvertedResidual-126            [-1, 160, 4, 4]               0
          Conv2d-127            [-1, 960, 4, 4]         153,600
     BatchNorm2d-128            [-1, 960, 4, 4]           1,920
           ReLU6-129            [-1, 960, 4, 4]               0
          Conv2d-130            [-1, 960, 4, 4]           8,640
     BatchNorm2d-131            [-1, 960, 4, 4]           1,920
           ReLU6-132            [-1, 960, 4, 4]               0
          Conv2d-133            [-1, 160, 4, 4]         153,600
     BatchNorm2d-134            [-1, 160, 4, 4]             320
InvertedResidual-135            [-1, 160, 4, 4]               0
          Conv2d-136            [-1, 960, 4, 4]         153,600
     BatchNorm2d-137            [-1, 960, 4, 4]           1,920
           ReLU6-138            [-1, 960, 4, 4]               0
          Conv2d-139            [-1, 960, 4, 4]           8,640
     BatchNorm2d-140            [-1, 960, 4, 4]           1,920
           ReLU6-141            [-1, 960, 4, 4]               0
          Conv2d-142            [-1, 160, 4, 4]         153,600
     BatchNorm2d-143            [-1, 160, 4, 4]             320
InvertedResidual-144            [-1, 160, 4, 4]               0
          Conv2d-145            [-1, 960, 4, 4]         153,600
     BatchNorm2d-146            [-1, 960, 4, 4]           1,920
           ReLU6-147            [-1, 960, 4, 4]               0
          Conv2d-148            [-1, 960, 4, 4]           8,640
     BatchNorm2d-149            [-1, 960, 4, 4]           1,920
           ReLU6-150            [-1, 960, 4, 4]               0
          Conv2d-151            [-1, 320, 4, 4]         307,200
     BatchNorm2d-152            [-1, 320, 4, 4]             640
InvertedResidual-153            [-1, 320, 4, 4]               0
          Conv2d-154           [-1, 1280, 4, 4]         409,600
     BatchNorm2d-155           [-1, 1280, 4, 4]           2,560
           ReLU6-156           [-1, 1280, 4, 4]               0
         Dropout-157                 [-1, 1280]               0
          Linear-158                 [-1, 1000]       1,281,000
================================================================
Total params: 3,504,872
Trainable params: 3,504,872
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.14
Forward/backward pass size (MB): 38.95
Params size (MB): 13.37
Estimated Total Size (MB): 52.47
----------------------------------------------------------------
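If no GPU is available, the same table can be produced on the CPU: summary takes a device argument (visible in its signature in the source below), so a minimal variant is:

from torchsummary import summary
from torchvision import models

net = models.mobilenet_v2()
# Run the dummy forward pass on the CPU instead of CUDA
summary(net, input_size=(3, 112, 112), device="cpu")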

3. torchsummary Source Code

  • Other bloggers have pointed out that torchsummary is not entirely rigorous in practice (some quantities are left out of its accounting) and have published modified versions. If you are interested, read through the source below and try a modification of your own; one possible tweak is sketched after the listing. Feel free to share it with me, I would like to learn from it as well.
import torch
import torch.nn as nn

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device="cuda"):

    def register_hook(module):

        def hook(module, input, output):
            # Key each entry as "<ClassName>-<running index>", e.g. "Conv2d-1"
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            # Count weight and bias parameters; trainability is taken
            # from the weight's requires_grad flag
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        # Hook every module except containers and the top-level model,
        # so only the layers that do real work appear in the table
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
    # print(type(x[0]))

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    # print(x.shape)
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
    # return summary

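As one concrete example of the looseness mentioned above: the hook only inspects module.weight and module.bias, so layers whose parameters use other attribute names (an nn.LSTM's weight_ih_l0, for instance) are reported with zero parameters. Below is a minimal sketch of a more general count using parameters(recurse=False); this is my own tweak, not part of the published package:

import torch.nn as nn


def count_direct_params(module: nn.Module):
    """Count all of a module's own parameters, whatever their
    attribute names, and report whether any of them are trainable."""
    params = 0
    trainable = False
    # recurse=False: only this module's own parameters, so children
    # visited by their own hooks are not double-counted
    for p in module.parameters(recurse=False):
        params += p.numel()
        trainable = trainable or p.requires_grad
    return params, trainable


# Inside hook(), the weight/bias block could then be replaced with:
#   summary[m_key]["nb_params"], summary[m_key]["trainable"] = count_direct_params(module)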