Viewing a Network Summary with torchsummary

This post compares a pretrained ResNet18 model with a custom CNN model in terms of parameter count, forward-pass size, and total memory footprint. After its input and output layers are adapted, ResNet18 has 11,171,779 parameters, while the custom CNN contains 317,066. Both summaries show the layer structure from input to output, revealing how the different convolutional and fully connected layers contribute to model complexity.
import torch
from torchsummary import summary
import torchvision

model = torchvision.models.resnet18(pretrained=True)  # newer torchvision uses weights=... instead of pretrained=True
# Adapt the network: 1-channel (grayscale) input and a 3-class output head
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = torch.nn.Linear(in_features=512, out_features=3)

# Use device to choose whether the network runs on the GPU or the CPU:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.to(device)
# (depending on the torchsummary version, summary() may default to device="cuda";
#  pass device="cpu" if running without a GPU)
summary(model, input_size=(1, 299, 299))

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 150, 150]           3,136
       BatchNorm2d-2         [-1, 64, 150, 150]             128
              ReLU-3         [-1, 64, 150, 150]               0
         MaxPool2d-4           [-1, 64, 75, 75]               0
            Conv2d-5           [-1, 64, 75, 75]          36,864
       BatchNorm2d-6           [-1, 64, 75, 75]             128
              ReLU-7           [-1, 64, 75, 75]               0
            Conv2d-8           [-1, 64, 75, 75]          36,864
       BatchNorm2d-9           [-1, 64, 75, 75]             128
             ReLU-10           [-1, 64, 75, 75]               0
       BasicBlock-11           [-1, 64, 75, 75]               0
           Conv2d-12           [-1, 64, 75, 75]          36,864
      BatchNorm2d-13           [-1, 64, 75, 75]             128
             ReLU-14           [-1, 64, 75, 75]               0
           Conv2d-15           [-1, 64, 75, 75]          36,864
      BatchNorm2d-16           [-1, 64, 75, 75]             128
             ReLU-17           [-1, 64, 75, 75]               0
       BasicBlock-18           [-1, 64, 75, 75]               0
           Conv2d-19          [-1, 128, 38, 38]          73,728
      BatchNorm2d-20          [-1, 128, 38, 38]             256
             ReLU-21          [-1, 128, 38, 38]               0
           Conv2d-22          [-1, 128, 38, 38]         147,456
      BatchNorm2d-23          [-1, 128, 38, 38]             256
           Conv2d-24          [-1, 128, 38, 38]           8,192
      BatchNorm2d-25          [-1, 128, 38, 38]             256
             ReLU-26          [-1, 128, 38, 38]               0
       BasicBlock-27          [-1, 128, 38, 38]               0
           Conv2d-28          [-1, 128, 38, 38]         147,456
      BatchNorm2d-29          [-1, 128, 38, 38]             256
             ReLU-30          [-1, 128, 38, 38]               0
           Conv2d-31          [-1, 128, 38, 38]         147,456
      BatchNorm2d-32          [-1, 128, 38, 38]             256
             ReLU-33          [-1, 128, 38, 38]               0
       BasicBlock-34          [-1, 128, 38, 38]               0
           Conv2d-35          [-1, 256, 19, 19]         294,912
      BatchNorm2d-36          [-1, 256, 19, 19]             512
             ReLU-37          [-1, 256, 19, 19]               0
           Conv2d-38          [-1, 256, 19, 19]         589,824
      BatchNorm2d-39          [-1, 256, 19, 19]             512
           Conv2d-40          [-1, 256, 19, 19]          32,768
      BatchNorm2d-41          [-1, 256, 19, 19]             512
             ReLU-42          [-1, 256, 19, 19]               0
       BasicBlock-43          [-1, 256, 19, 19]               0
           Conv2d-44          [-1, 256, 19, 19]         589,824
      BatchNorm2d-45          [-1, 256, 19, 19]             512
             ReLU-46          [-1, 256, 19, 19]               0
           Conv2d-47          [-1, 256, 19, 19]         589,824
      BatchNorm2d-48          [-1, 256, 19, 19]             512
             ReLU-49          [-1, 256, 19, 19]               0
       BasicBlock-50          [-1, 256, 19, 19]               0
           Conv2d-51          [-1, 512, 10, 10]       1,179,648
      BatchNorm2d-52          [-1, 512, 10, 10]           1,024
             ReLU-53          [-1, 512, 10, 10]               0
           Conv2d-54          [-1, 512, 10, 10]       2,359,296
      BatchNorm2d-55          [-1, 512, 10, 10]           1,024
           Conv2d-56          [-1, 512, 10, 10]         131,072
      BatchNorm2d-57          [-1, 512, 10, 10]           1,024
             ReLU-58          [-1, 512, 10, 10]               0
       BasicBlock-59          [-1, 512, 10, 10]               0
           Conv2d-60          [-1, 512, 10, 10]       2,359,296
      BatchNorm2d-61          [-1, 512, 10, 10]           1,024
             ReLU-62          [-1, 512, 10, 10]               0
           Conv2d-63          [-1, 512, 10, 10]       2,359,296
      BatchNorm2d-64          [-1, 512, 10, 10]           1,024
             ReLU-65          [-1, 512, 10, 10]               0
       BasicBlock-66          [-1, 512, 10, 10]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                    [-1, 3]           1,539
================================================================
Total params: 11,171,779
Trainable params: 11,171,779
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.34
Forward/backward pass size (MB): 114.26
Params size (MB): 42.62
Estimated Total Size (MB): 157.21
----------------------------------------------------------------
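
The Param # column can be reproduced by hand from the layer shapes. A minimal sanity check (plain arithmetic plus a standard PyTorch one-liner; reuses the model object from the code above):

# Conv2d: in_channels * out_channels * kernel_h * kernel_w (conv1 has bias=False)
conv1_params = 1 * 64 * 7 * 7   # 3,136  -> Conv2d-1
# BatchNorm2d: 2 * num_features (weight + bias)
bn1_params = 2 * 64             # 128    -> BatchNorm2d-2
# Linear: in_features * out_features + out_features (bias)
fc_params = 512 * 3 + 3         # 1,539  -> Linear-68

# The grand total matches the summary line "Total params: 11,171,779":
assert sum(p.numel() for p in model.parameters()) == 11_171_779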

Or, with a custom model:

import torch
import torch.nn as nn
import math
from torchsummary import summary

class CNN(nn.Module):
    def __init__(self, num_channel, num_classes, num_pixel):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.act = nn.ReLU(inplace=True)

        ### Conv2d output size:
        ### X_out = floor(1 + (X_in + 2*padding - dilation*(kernel_size-1) - 1) / stride)
        X = num_pixel
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)  # conv1 (5x5, no padding)
        X = X // 2                                              # 2x2 max pool
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)  # conv2 (5x5, no padding)
        X = X // 2                                              # 2x2 max pool
        # For num_pixel=28: X = 4
        print(X)

        self.fc1 = nn.Linear(64 * X * X, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x

cnn = CNN(num_channel=1, num_classes=10, num_pixel=28)
summary(cnn, input_size=(1, 28, 28))
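
(The print(X) call in __init__ runs when the model is constructed, so a 4 is printed before the summary table: the 28x28 input shrinks to 24 after conv1, 12 after the first pool, 8 after conv2, and 4 after the second pool.)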

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 24, 24]             832
              ReLU-2           [-1, 32, 24, 24]               0
         MaxPool2d-3           [-1, 32, 12, 12]               0
            Conv2d-4             [-1, 64, 8, 8]          51,264
              ReLU-5             [-1, 64, 8, 8]               0
         MaxPool2d-6             [-1, 64, 4, 4]               0
            Linear-7                  [-1, 256]         262,400
              ReLU-8                  [-1, 256]               0
            Linear-9                   [-1, 10]           2,570
================================================================
Total params: 317,066
Trainable params: 317,066
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.39
Params size (MB): 1.21
Estimated Total Size (MB): 1.60
----------------------------------------------------------------
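
Again, the Param # column follows directly from the shapes defined in the code; a quick hand check:

conv1 = (1 * 5 * 5 + 1) * 32     # 832     (bias=True adds one parameter per output channel)
conv2 = (32 * 5 * 5 + 1) * 64    # 51,264
fc1 = (64 * 4 * 4) * 256 + 256   # 262,400 (flattened 4x4x64 feature map)
fc2 = 256 * 10 + 10              # 2,570
assert conv1 + conv2 + fc1 + fc2 == 317_066

# Params size (MB): float32 weights at 4 bytes each
print(317_066 * 4 / 2**20)  # ~= 1.21, matching the summary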