转载Unet网络参数的计算

转自:https://bbs.cvmart.net/articles/1632

  1. 理论公式部分
    卷积层:在这里插入图片描述

其中 K 为卷积核大小,Ci为输入channel数,Co为输出的channel数(也是filter的数量),算式第二项是偏置项的参数量 。(虽然一般不写偏置项,因为不会影响总参数量的数量级,但是我们为了准确起见,把偏置项的参数量也考虑进来)

BN层:在这里插入图片描述
,其中Ci为输入的channel数
(BN层有两个需要学习的参数,平移因子和缩放因子)

全连接层:
,Ti为输入向量的长度,To为输出向量的长度,其中第二项为偏置项参数量。 (不过目前全连接层已经逐渐被Global Average Pooling层取代了)

  1. 实践部分
    我们首先摆出经典的UNet结构图。
    在这里插入图片描述

我们把UNet共分为5个Stage,分别计算每个stage的参数量。每个stage的filter数量为

[32,64,128,256,512],相比于UNet原文,我们把UNet的channel数缩小了两倍,大多数论文也的确是这么做的。同时,我们设置UNet上采样方式为TransposeConv(转置卷积),并在每个[公式]Conv后加入BN层。最后假定,原始输入channel为1,输出分割图为两类(含背景),这样最终得到我们要计算参数量的UNet。

这样定义的UNet主要有四个组件,3×3Conv,1×1Conv,TransposeConv和BN层。

我们先计算Conv,再计算BN层。

Stage1:
在这里插入图片描述
Stage2:
在这里插入图片描述
Stage3:


Stage4:
在这里插入图片描述
Stage5:
在这里插入图片描述
TransposeConv:

[公式]
目前为止,我们把以上所有Conv得到的参数量求和 ,得到了没有加BN的UNet参数量。
在这里插入图片描述
接下来,我们计算BN层的参数量,易得:

[公式]
和刚求得的参数量求和
在这里插入图片描述
加入了BN的UNet参数量为7765442。

我们得到的两种UNet的参数量,如何验证其正确性呢?或者说我们不可能每次都手动计算这些网络的参数量,这就需要脚本去帮我们去计算网络的参数量。代码如下:

def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count

我们通过项目【UNet-family】中提供的脚本,可以直接使用命令python UNet.py来计算网络的参数量。通过设置参数is_batchnorm我们得到两次计算结果如下:

有BN:
file

没有BN:
file

赶快看一眼我们之前算的参数量,一模一样,说明我们算的没毛病。

有没有一种再也不怕面试官问自己参数量计算的感觉。

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值