VGG16网络的参数和FLOPs计算
1、前言
VGG16是一种卷积神经网络(Convolutional Neural Network,CNN),是由Simonyan和Zisserman于2014年提出的用于图像识别任务的其中一种网络结构。VGG16网络包含13个卷积层和3个全连接层,在ImageNet数据集上表现出色,并成为后来CNN结构的基础。下面对VGG16网络的参数和FLOPs进行计算。
2、计算
1、模型的参数量
1.1、卷积层计算
模型的参数量=[卷积核的长∗卷积核的宽∗卷积核的高(即通道,由上一层的输出通道决定)]∗卷积核的数量+偏置参数(其等于卷积核的数量)以第一个卷积层为例Conv1_1=kernel(Height)*kernel(Width)InPut(Channel)OutPut(Channel)=(333)*64=1728
1.2、全连接层
模型的参数量=上一层输入的长宽高(即通道)本层的长宽*高(即通道)以第一个全连接层为例:
FC1[(LayerID21,Output(Height))∗(LayerID21,Output(Width))∗(LayerID21,Output(Channel))]∗[(LayerID22,Output(Height))∗(LayerID22,Output(Width))∗(LayerID22,Output(Channel))]=[(7∗7∗512)]∗[1∗1∗4096]=102,760,448
2、模型FLOPs
2.1、卷积层计算
FLOPS数量=参数量∗该层输出特征图的大小该层输出特征图的大小:以第一个卷积层为例FLOPS=OutPut(Height)OutPut(Weight)Params=2242241728=86704128
2.2、全连接层 由于不存在权值共享,它的FLOPs数目即是该层参数数目:
以第一个全连接层为例
FLOPs=Params=102760448
3、总参数量 138,357,544,总FLOPs=15470314496
3、EXCEL详细结果
4、代码
import torch
import torch.nn as nn
from torchvision.models import vgg16
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def count_flops(model, input_size):
flops = 0
input = torch.randn(1, *input_size)
def conv_hook(module, input, output):
nonlocal flops
batch_size, input_channels, input_height, input_width = input[0].size()
output_channels, output_height, output_width = output[0].size()
kernel_height, kernel_width = module.kernel_size
flops += batch_size * output_channels * output_height * output_width * (
input_channels * kernel_height * kernel_width + 1)
def fc_hook(module, input, output):
nonlocal flops
batch_size, input_features = input[0].size()
output_features = output[0].size(0) # 修改这里
flops += batch_size * input_features * output_features
hooks = []
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
hooks.append(module.register_forward_hook(conv_hook))
elif isinstance(module, nn.Linear):
hooks.append(module.register_forward_hook(fc_hook))
model(input)
for hook in hooks:
hook.remove()
return flops
model = vgg16()
params = count_parameters(model)
flops = count_flops(model, (3, 224, 224))
print(f"Parameters: {params}")
print(f"FLOPs: {flops}")
运行结果如下