Pytorch-nn.Conv2d中groups参数的理解

Pytorch-nn.Conv2d中groups参数的理解

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

group参数的作用是控制分组卷积。

直接通过实际的例子加以说明。

import torch
import torch.nn as nn

x = torch.Tensor([1, 10, 100, 1000, 10000, 100000]).view(1, -1, 1, 1)
print("x:", x.int())
conv = nn.Conv2d(
    in_channels=6, out_channels=9, kernel_size=1, stride=1, padding=0, groups=3, bias=False
)
print("Conv weight size:", conv.weight.data.size())
conv.weight.data = torch.arange(1, 19).float().view(9, 2, 1, 1)
print("Conv weight data:", conv.weight.data.int())
output = conv(x).int()
print("Output:", output)

如果是正常的卷积,参数大小应该为: [9(输出通道), 6(输入通道), 1(核h), 1(核w)]。
这是因为输出是9个通道,每个通道都需要一个[6, 1, 1]大小的卷积(输入的每个通道都参与到了运算)。
但是我们可以从代码的运行结果中看到Conv层的参数大小为: [9, 2, 1, 1]。这就说明对于每个输出的通道,只有两个输入的通道参与了运算。
事实就是这样,分组卷积的过程中只有部分输入的通道才参与了运算。我们就以上面的代码为例进行讲解。

  • 首先将输入的6个通道分为3组: [1, 10], [100, 1000], [10000, 100000],每一组都用来生成输出的一个通道。
  • 3个组只能生成3个输出通道,但是要求输出是9个通道,所以每个组需要重复计算三次。
  • 输出的第1个通道: 1 * 1 + 2 * 10 = 21,需要用到输入的第1组。
    输出的第2个通道: 3 * 1 + 4 * 10 = 43,需要用到输入的第1组。
    输出的第3个通道: 5 * 1 + 6 * 10 = 65,需要用到输入的第1组。
    输出的第4个通道: 7 * 100 + 8 * 1000 = 8700,需要用到输入的第2组。

根据上面的分析,我们可以得出结论: group参数必须整除输入的通道数(保证输入的通道能被正确分组),
还必须整除输出的通道数(保证group个分组重复若干次之后恰好等于输出通道数)。

  • 7
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值