Pytorch中nn.Conv2d数据计算模拟

Pytorch中nn.Conv2d数据计算模拟

最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块。在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin, W, H) 其中B为batch_size表示batch的大小,Cin为输入数据的特征大小(通道数),W、H对于图像数据来说分别表示图像数据的宽和高。输出数据为(B, Cout, W', H')其中Cout表示输出的特征大小,W’, H’取决于W, H,具体转换方式如下图所示:
Conv2d输入输出
通过查询nn.Conv2d的源码可知,nn.Conv2d底层是由nn.functional.Conv2d实现的,所以可以使用nn.functional.Conv2d模拟nn.Conv2d操作。

# nn.Conv2d源码
    def conv2d_forward(self, input, weight):
        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        return self.conv2d_forward(input, self.weight)

根据论文,第一层EdgeConv使用的stride=1,bias=None,因此使用一下代码模拟:

import torch
import torch.nn as nn
import torch.nn.functional as F
data = torch.randint(0,10,(2,3,5,4)).float()  # 输入数据 B=2, C=3, 点云中点的个数num_points=5, Knn选取的邻近点数K
w = torch.randint(0,10,(4,3,1,1)).float()  # 权重 文档显示权重的shape为(output_channel, input_channel/group, kenerl_size[0], kenerl_size[1]), 这里group去默认值1, kenerl_size为(1, 1)
b = torch.Tensor([0,0,0,0])  # bias=None
out = F.conv2d(data,w,b)

得到的结果(data、out取第一个batch)为:

data = [[[[2., 9., 6., 0.],
                  [2., 8., 8., 2.],
                  [0., 2., 8., 0.],
                  [9., 5., 8., 7.],
                  [6., 7., 6., 5.]],

                [[6., 6., 9., 2.],
                 [7., 0., 6., 5.],
                 [7., 7., 1., 1.],
                 [7., 0., 7., 3.],
                 [3., 8., 0., 6.]],

                [[9., 2., 5., 2.],
                  [0., 1., 6., 9.],
                  [3., 3., 6., 0.],
                  [9., 8., 2., 0.],
                  [6., 7., 8., 9.]]]

 w =  [[[[1.]],
             [[6.]],
             [[0.]]],

            [[[4.]],
             [[7.]],
             [[1.]]],

           [[[0.]],
            [[2.]],
            [[0.]]],

          [[[7.]],
           [[6.]],
           [[8.]]]])

out = [[[ 38.,  45.,  60.,  12.],
               [ 44.,   8.,  44.,  32.],
               [ 42.,  44.,  14.,   6.],
               [ 51.,   5.,  50.,  25.],
               [ 24.,  55.,   6.,  41.]],

             [[ 59.,  80.,  92.,  16.],
              [ 57.,  33.,  80.,  52.],
              [ 52.,  60.,  45.,   7.],
              [ 94.,  28.,  83.,  49.],
              [ 51.,  91.,  32.,  71.]],

            [[ 12.,  12.,  18.,   4.],
             [ 14.,   0.,  12.,  10.],
             [ 14.,  14.,   2.,   2.],
             [ 14.,   0.,  14.,   6.],
             [  6.,  16.,   0.,  12.]],

          [[122., 115., 136.,  28.],
           [ 56.,  64., 140., 116.],
           [ 66.,  80., 110.,   6.],
           [177.,  99., 114.,  67.],
          [108., 153., 106., 143.]]]

其中data[ :, 0, 0]表示的是点云中第0个点的第0个邻近点的3个通道[2, 6, 9], data[ :, 0, 1]=[9, 6, 2]以此类推为第0个点的第1个邻近点的3个通道[9, 6, 2], out中的第0个元素out[0,0,0]=38其实是由data[ :, 0, 0]与w[0, :, :] = [1, 6, 0]点乘得到的: 38 = 1 X 2 + 6 X 6 + 0 X 9, 同理, out[0, 0, 1]=45=1 X 9 + 6 X 6 + 0 X 2, out[0, 1, 0]=59=w[1, :, :] dot data[:, 0, 0] = [4, 7, 1] dot [2, 6, 9] = 4 X 2 + 7 X 6 + 1 X 9 = 59。

以上。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值