VGG16网络的参数和FLOPs计算

VGG16网络的参数和FLOPs计算

1、前言

VGG16是一种卷积神经网络(Convolutional Neural Network,CNN),是由Simonyan和Zisserman于2014年提出的用于图像识别任务的其中一种网络结构。VGG16网络包含13个卷积层和3个全连接层,在ImageNet数据集上表现出色,并成为后来CNN结构的基础。下面对VGG16网络的参数和FLOPs进行计算。


在这里插入图片描述

2、计算

1、模型的参数量

1.1、卷积层计算
模型的参数量=[卷积核的长∗卷积核的宽∗卷积核的高(即通道,由上一层的输出通道决定)]∗卷积核的数量+偏置参数(其等于卷积核的数量)以第一个卷积层为例Conv1_1=kernel(Height)*kernel(Width)InPut(Channel)OutPut(Channel)=(333)*64=1728

1.2、全连接层
模型的参数量=上一层输入的长高(即通道)本层的长宽*高(即通道)以第一个全连接层为例:
FC1[(LayerID21,Output(Height))∗(LayerID21,Output(Width))∗(LayerID21,Output(Channel))]∗[(LayerID22,Output(Height))∗(LayerID22,Output(Width))∗(LayerID22,Output(Channel))]=[(7∗7∗512)]∗[1∗1∗4096]=102,760,448

2、模型FLOPs

2.1、卷积层计算
FLOPS数量=参数量∗该层输出特征图的大小该层输出特征图的大小:以第一个卷积层为例FLOPS=OutPut(Height)OutPut(Weight)Params=2242241728=86704128

2.2、全连接层 由于不存在权值共享,它的FLOPs数目即是该层参数数目:
以第一个全连接层为例
FLOPs=Params=102760448

3、总参数量 138,357,544,总FLOPs=15470314496

3、EXCEL详细结果

|  |  |
|--|--|
|  |  |

4、代码

import torch
import torch.nn as nn
from torchvision.models import vgg16

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def count_flops(model, input_size):
    flops = 0
    input = torch.randn(1, *input_size)

    def conv_hook(module, input, output):
        nonlocal flops
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_height, kernel_width = module.kernel_size
        flops += batch_size * output_channels * output_height * output_width * (
                    input_channels * kernel_height * kernel_width + 1)

    def fc_hook(module, input, output):
        nonlocal flops
        batch_size, input_features = input[0].size()
        output_features = output[0].size(0)  # 修改这里
        flops += batch_size * input_features * output_features

    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            hooks.append(module.register_forward_hook(conv_hook))
        elif isinstance(module, nn.Linear):
            hooks.append(module.register_forward_hook(fc_hook))

    model(input)

    for hook in hooks:
        hook.remove()

    return flops

model = vgg16()
params = count_parameters(model)
flops = count_flops(model, (3, 224, 224))

print(f"Parameters: {params}")
print(f"FLOPs: {flops}")

运行结果如下
在这里插入图片描述

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值