Pytorch中nn.Conv2d数据计算模拟
最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块。在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin, W, H)
其中B为batch_size表示batch的大小,Cin为输入数据的特征大小(通道数),W、H对于图像数据来说分别表示图像数据的宽和高。输出数据为(B, Cout, W', H')
其中Cout表示输出的特征大小,W’, H’取决于W, H,具体转换方式如下图所示:
通过查询nn.Conv2d的源码可知,nn.Conv2d底层是由nn.functional.Conv2d实现的,所以可以使用nn.functional.Conv2d模拟nn.Conv2d操作。
# nn.Conv2d源码
def conv2d_forward(self, input, weight):
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
(self.padding[0] + 1) // 2, self.padding[0] // 2)
return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input):
return self.conv2d_forward(input, self.weight)
根据论文,第一层EdgeConv使用的stride=1,bias=None,因此使用一下代码模拟:
import torch
import torch.nn as nn
import torch.nn.functional as F
data = torch.randint(0,10,(2,3,5,4)).float() # 输入数据 B=2, C=3, 点云中点的个数num_points=5, Knn选取的邻近点数K
w = torch.randint(0,10,(4,3,1,1)).float() # 权重 文档显示权重的shape为(output_channel, input_channel/group, kenerl_size[0], kenerl_size[1]), 这里group去默认值1, kenerl_size为(1, 1)
b = torch.Tensor([0,0,0,0]) # bias=None
out = F.conv2d(data,w,b)
得到的结果(data、out取第一个batch)为:
data = [[[[2., 9., 6., 0.],
[2., 8., 8., 2.],
[0., 2., 8., 0.],
[9., 5., 8., 7.],
[6., 7., 6., 5.]],
[[6., 6., 9., 2.],
[7., 0., 6., 5.],
[7., 7., 1., 1.],
[7., 0., 7., 3.],
[3., 8., 0., 6.]],
[[9., 2., 5., 2.],
[0., 1., 6., 9.],
[3., 3., 6., 0.],
[9., 8., 2., 0.],
[6., 7., 8., 9.]]]
w = [[[[1.]],
[[6.]],
[[0.]]],
[[[4.]],
[[7.]],
[[1.]]],
[[[0.]],
[[2.]],
[[0.]]],
[[[7.]],
[[6.]],
[[8.]]]])
out = [[[ 38., 45., 60., 12.],
[ 44., 8., 44., 32.],
[ 42., 44., 14., 6.],
[ 51., 5., 50., 25.],
[ 24., 55., 6., 41.]],
[[ 59., 80., 92., 16.],
[ 57., 33., 80., 52.],
[ 52., 60., 45., 7.],
[ 94., 28., 83., 49.],
[ 51., 91., 32., 71.]],
[[ 12., 12., 18., 4.],
[ 14., 0., 12., 10.],
[ 14., 14., 2., 2.],
[ 14., 0., 14., 6.],
[ 6., 16., 0., 12.]],
[[122., 115., 136., 28.],
[ 56., 64., 140., 116.],
[ 66., 80., 110., 6.],
[177., 99., 114., 67.],
[108., 153., 106., 143.]]]
其中data[ :, 0, 0]表示的是点云中第0个点的第0个邻近点的3个通道[2, 6, 9], data[ :, 0, 1]=[9, 6, 2]以此类推为第0个点的第1个邻近点的3个通道[9, 6, 2], out中的第0个元素out[0,0,0]=38其实是由data[ :, 0, 0]与w[0, :, :] = [1, 6, 0]点乘得到的: 38 = 1 X 2 + 6 X 6 + 0 X 9, 同理, out[0, 0, 1]=45=1 X 9 + 6 X 6 + 0 X 2, out[0, 1, 0]=59=w[1, :, :] dot data[:, 0, 0] = [4, 7, 1] dot [2, 6, 9] = 4 X 2 + 7 X 6 + 1 X 9 = 59。
以上。