pytorch中Conv2d的FLOPs的计算范例

3 篇文章 0 订阅
2 篇文章 0 订阅
import torch

conv = torch.nn.Conv2d(1,8,(2,3))
input = torch.rand(1,1,224,224) # batch,channel,width,height
output = conv(input)
print(output.shape)
bn = torch.nn.BatchNorm2d(8)
l = [conv,bn]
for module in l:
    class_name = str(module.__class__.__name__)
    if class_name.find("Conv") != -1 and hasattr(module, "weight"):
        # flops为卷积核的参数量乘以输出特征图的分辨率
        # 即inchannel*outchannel*kernel_width*kernel_height*output_width*output_height
        flops = (
            torch.prod(
                torch.LongTensor(list(module.weight.data.size()))) *
            torch.prod(
                torch.LongTensor(list(output.size())[2:]))).item()
        print(list(module.weight.data.size()))
        print(list(output.size()))
        print(flops)

Conv2d层的计算量FlOPs为:输入通道数×输出通道数×卷积核宽×卷积核高×输出特征图宽×输出特征图高。故此卷积层的计算量为1×8×2×3×223×222=2376288
而BN一般计算量比较小,不算FLOPs

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值