【Pytorch】一文搞懂nn.Conv2d的groups参数的作用

本文详细解释了Pytorch中nn.Conv2d模块的groups参数,通过语言描述和代码实例展示了groups=1和groups=2时卷积操作的区别。当groups=2时,输入和输出通道被分组,每组独立进行卷积,然后将结果concat。代码验证部分展示了不同groups设置下卷积输出特征图的变化,帮助读者直观理解分组卷积的工作原理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. 语言描述

在Pytorch1.13的官方文档中,关于nn.Conv2d中的groups的作用是这么描述的:
在这里插入图片描述
简单来说就是将输入和输出的通道(channel)进行分组,每一组单独进行卷积操作,然后再把结果拼接(concat)起来。

比如输入大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5),输出大小为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5) g r o u p s = 2 groups=2 groups=2。就是将输入的4个channel分成2个2的channel,输出的8个channel分成2个4的channel,每个输入的2个channel和输出的4个channel组成一组,每组做完卷积后的输出大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)。然后把得到的两组输出在channel这个维度上进行concat,得到最后的输出维度为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)

但其实这么描述理解起来不够直观,下面我举个例子,先从语言上进行详细的解释,然后再进行代码验证。

符号数值含义
i n p u t _ c h a n n e l input\_channel input_channel4输入通道数量
o n p u t _ c h a n n e l onput\_channel onput_channel8输出通道数量,其实就是卷积核的个数,我们将其看作卷积核的个数会更容易理解
b a t c h _ s i z e batch\_size batch_size1批量大小为1
H , W H, W H,W5输入输出的feature大小为5x5
i n p u t _ s h a p e input\_shape input_shape ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)输入的shape,注意我们这里设置输入的所有元素都为1,即输入是一个全1的tensor
o u t p u t _ s h a p e output\_shape output_shape ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)输出的shape
k e r n e l _ s i z e kernel\_size kernel_size3卷积核的大小为3x3
p a d d i n g padding padding1填充长度为1,这里我们使用1填充(即周围补一圈1),而不是0填充
s t r i d e stride stride1步长为1

我们假设输入tensor的shape为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)输出tensor的shape为: ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5),即我们的卷积核有8个。下面的图由于 b a t c h _ s i z e = 1 batch\_size=1 batch_size=1,所以省略的 b a t c h _ s i z e batch\_size batch_size的维度。
在这里插入图片描述

值得注意的是,这里我们手动设置卷积核中元素的值,前4个卷积核的值都设置为1,后4个卷积核的值都设置为2,如下图所示:

在这里插入图片描述
这里解释一下为什么 g r o u p s = 1 groups=1 groups=1 k e r n e l _ s i z e = ( 4 , 3 , 3 ) kernel\_size=(4, 3, 3) kernel_size=(4,3,3) g r o u p s = 2 groups=2 groups=2 k e r n e l _ s i z e = ( 2 , 3 , 3 ) kernel\_size=(2, 3, 3) kernel_size=(2,3,3):因为 g r o u p s = 2 groups=2 groups=2时,输入和输出都被分成了两组,输入的shape原来为: ( 4 , 5 , 5 ) (4, 5, 5) (4,5,5),被分成了两个 ( 2 , 5 , 5 ) (2, 5, 5) (2,5,5),所以每个 k e r n e l _ s i z e kernel\_size kernel_size也由 ( 4 , 3 , 3 ) (4, 3, 3) (4,3,3)变为 ( 2 , 3 , 3 ) (2, 3, 3) (2,3,3)

下面我们来看一下 g r o u p s = 1 groups=1 groups=1 g r o u p s = 2 groups=2 groups=2时计算过程的不同:

【情况1:groups=1】
此时就和正常卷积一样:
在这里插入图片描述

这里解释一下:output的前4个channel的每个feature map的所有元素都为36,后4个channel的每个feature map的所有元素都为72,这是因为:
每个输入的 H , W H,W H,W是5x5,加上padding之后是6x6,具体过程如下:
在这里插入图片描述

【情况1:groups=2】
此时应当这么算:
在这里插入图片描述
为什么output的前4个channel的每个feature map的所有元素都为18,后4个channel的每个feature map的所有元素都为36呢?看了下面的图应该就能理解这个过程了:
在这里插入图片描述

2. 代码验证:

实验环境:Python3.7,torch1.10.2
代码:

import os

import torch
import torch.nn as nn


if __name__ == '__main__':
    input_dim, output_dim = 4, 8
    X = torch.ones(1, input_dim, 5, 5)

    # groups = 1
    conv1 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=1, bias=False, padding_mode='replicate')
    print(f'groups=1时,卷积核的形状为:{conv1.weight.shape}')
    with torch.no_grad():
        conv1.weight[:4, :, :, :] = torch.ones(4, 4, 3, 3)
        conv1.weight[4:, :, :, :] = torch.ones(4, 4, 3, 3) * 2
        Y1 = conv1(X)
        print(f'结果为:\n{Y1}')

    # groups = 2
    conv2 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=2, bias=False, padding_mode='replicate')
    print(f'groups=2时,卷积核的形状为:{conv2.weight.shape}')
    with torch.no_grad():
        conv2.weight[:4, :, :, :] = torch.ones(4, 2, 3, 3)
        conv2.weight[4:, :, :, :] = torch.ones(4, 2, 3, 3) * 2
        Y2 = conv2(X)
        print(f'结果为:\n{Y2}')


结果:

groups=1时,卷积核的形状为:torch.Size([8, 4, 3, 3])
结果为:
tensor([[[[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.]],

         [[72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.]],

         [[72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.]],

         [[72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.]]]])
groups=2时,卷积核的形状为:torch.Size([8, 2, 3, 3])
结果为:
tensor([[[[18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.]],

         [[18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.]],

         [[18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.]],

         [[18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.],
          [18., 18., 18., 18., 18.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]]]])

Process finished with exit code 0

整体流程我手画了个图,我感觉比PPT画的还清楚,可以更好地理解过程在这里插入图片描述

END:)

p.s.:没想到写个博客写了一上午,画图太费时间了!本来上午还有别的事情的。。。只能推到下午再做了0.0
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

SinHao22

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值