关于ShuffleNetV1中的channel shuffle操作【代码分析】

1. 官方给出的代码 

旷视科技在自己的开源GitHub上给出的channel shuffle相关代码如下图所示:

      分析上图中的代码,旷视科技将channel shuffle这个操作视为一个函数,函数传入的参数是输入张量x,x的shape为(batchsize, num_channels, height, width)。
      首先对输入张量x使用data.size() 方法进行解包,从输入张量 x 中提取批量大小、通道数、高度和宽度。
      使用assert函数检查 group【分组个数】是否能够整除num_channels,若不能够整除,则函数运行到此处抛出AssertionError 异常;若能够整除,则正常运行。
      每个小组的通道数group_channels 为总体通道数num_channels除分组个数group。
接下来的三行代码均对x操作,我们一步一步来剖析:
首先经过:

x = x.reshape(batchsize, group_channels, self.group, height, width)

      这表示x需要经过reshape操作,将num_channels分为group个组,每个组中的通道数为group_channels。

在经过如下操作:

x = x.permute(0, 2, 1, 3, 4)

      这表示要将x的第1个维度与第2个维度进行互换,也就是说,可以理解为在这里对x经历了转置操作。

  • 重新排列维度,使得维度的顺序变为:
    • 维度 0:批量大小保持不变。
    • 维度 2:将组的维度移到第二位。
    • 维度 1:将每组的通道维度移到后面。
    • 维度 3 和 4:高度和宽度保持不变。
  • 这一步的作用是将不同组的通道位置互换,从而实现通道间的信息交互。

然后再经过如下操作:

x = x.reshape(batchsize, num_channels, height, width)

将重排后的张量重塑回 (batchsize, num_channels, height, width) 原始形状。
最后借助 return x 返回channel shuffle后的张量。

总结 

该方法实现了channel shuffle的过程,通过将通道分组、重排和恢复形状来增强通道间的信息交互,通常用于提升轻量级网络的性能。channel shuffle有助于使模型更好地利用特征共享,提高整体表现。 

 2. 喂入测试张量进行测试【图例分析】

假设输入张量的shape为:(1, 12, 1, 1)  group=3
首先通过以下代码构建输入张量,使用unsqueeze函数是为了给一维张量进行扩维,使之符合输入张量的shape。

对官方代码小修小改,得到独立可运行的channel_shuffle函数,如下图所示:

以图说明上述代码:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值