PyTorch中CNN网络参数计算和模型文件大小预估

前言

在深度学习CNN构建过程中,网络的参数量是一个需要考虑的问题。太深的网络或是太大的卷积核、太多的特征图通道数都会导致网络参数量上升。写出的模型文件也会很大。所以提前计算网络参数和预估模型文件大小很重要。
在这里插入图片描述

网络参数计算

先定义好网络结构,然后统计网络参数。

网络定义

以LeNet-5为例,参考之前的博客《PyTorch构建网络示例:LeNet-5》,网络结构设计代码如下:

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import os

class lenet5(nn.Module):
    def __init__(self):
        super(lenet5,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)
        self.conv2 = nn.Conv2d(6,16,5,1)
        self.pool2 = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(4*4*16,120)#注:按原始的minst数据集,输入为图示中的32*32时,此处应该是5*5*16.但是按torchvision.datasets中的输入大小则是28*28,此处为4*4*16。或者直接在Data.DataLoader时将输入transforms到32*32大小
        self.fc2 = nn.Linear(120,84)
        self.out = nn.Linear(84,10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)#flatten the output of pool2 to (batch_size, 16 * 4 * 4),x.shape[0]为batch_size,-1为自适应调整大小
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.out(x),dim=-1)
        return x

net = lenet5()

参数计算

定义参数统计函数并传入实例化的net:

def cnn_paras_count(net):
    """cnn参数量统计, 使用方式cnn_paras_count(net)"""
    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')
    return total_params, total_trainable_params

cnn_paras_count(net)

输出结果:
在这里插入图片描述

模型文件大小预估

需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:

44426个参数*4 / 1024 ≈ 174KB

查看网络训练中实际存储的模型文件大小:
在这里插入图片描述
相差的部分是模型文件中除了需要存储实际参数外还需要存储一些网络信息。总体差异不大。
更进一步可以转化为MBGB等单位。

后记

CNN网络参数计算和模型文件大小预估对实际的网络设计很有帮助,可以提前了解所设计的网络的复杂度,并根据模型文件大小提前规划硬盘存储空间。

在这里插入图片描述

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TracelessLe

❀点个赞加个关注再走吧❀

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值