First we need to sort out how 5D data deforms under reshaping, because we will be flattening dimensions. Let's start with a simple 3D example: flatten the data, then restore it to its original shape:
import torch

x = torch.arange(24).view(2, 3, 4)
print(x)
x = x.transpose(0, 2)     # (4, 3, 2)
x = x.transpose(0, 1)     # (3, 4, 2)
x = x.reshape(-1)         # flatten in the transposed order
print(x)
x = x.reshape(4, 3, 2)
x = x.transpose(2, 0)     # (2, 3, 4)
x = x.transpose(1, 2)     # (2, 4, 3)
x = x.reshape(2, 3, 4)    # back to the original layout
print(x)
Output:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([ 0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20,
9, 21, 10, 22, 11, 23])
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
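The two transposes amount to a single `permute`, which makes the round trip easier to see at a glance. A small sketch double-checking that the original order is restored:

```python
import torch

x = torch.arange(24).view(2, 3, 4)
# permute(1, 2, 0) is the same reordering as transpose(0, 2) followed by transpose(0, 1)
flat = x.permute(1, 2, 0).reshape(-1)
print(flat[:6])                                # interleaved: 0, 12, 1, 13, 2, 14
back = flat.reshape(3, 4, 2).permute(2, 0, 1)  # invert the permutation
print(torch.equal(back, x))                    # True
```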
Now the same flatten-and-restore for 5D data:
x = torch.arange(720).view(5, 6, 2, 3, 4)
print(x)
x = x.transpose(2, 4)         # (5, 6, 4, 3, 2)
x = x.transpose(2, 3)         # (5, 6, 3, 4, 2)
x = x.reshape(5, 6, 24)       # flatten the last three dims
print(x)
x = x.reshape(5, 6, 4, 3, 2)
x = x.transpose(3, 4)         # (5, 6, 4, 2, 3)
x = x.transpose(2, 3)         # (5, 6, 2, 4, 3)
x = x.reshape(5, 6, 2, 3, 4)  # back to the original layout
print(x)
That's too much data to print here; the point is that after flattening, reshaping back restores the original element order exactly.
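To make the "restores the original order" claim concrete, here is a quick equality check with `torch.equal`:

```python
import torch

orig = torch.arange(720).view(5, 6, 2, 3, 4)
x = orig.transpose(2, 4).transpose(2, 3).reshape(5, 6, 24)    # flatten
x = x.reshape(5, 6, 4, 3, 2).transpose(3, 4).transpose(2, 3)  # un-flatten
x = x.reshape(5, 6, 2, 3, 4)
print(torch.equal(x, orig))  # True
```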
Now for the key code:
import torch
from torch import nn

xs = torch.randn(8, 3, 16, 32, 32)   # (batch, channels, depth, height, width)
conv3d = nn.Conv3d(3, 32, bias=False, kernel_size=3, padding=1)
conv3d_result = conv3d(xs)
print(conv3d_result.shape)           # torch.Size([8, 32, 16, 32, 32])

# 2D convolution applied to each depth slice
conv2d = nn.Conv2d(3, 32, bias=False, kernel_size=3, padding=1)
conv3x3_result = []
for slice_2d in xs.unbind(dim=2):    # each slice: (8, 3, 32, 32)
    conv3x3_result.append(conv2d(slice_2d))
conv3x3_result = torch.stack(conv3x3_result, dim=2)

# 1D convolution along the flattened depth axis
b, c, d, h, w = conv3x3_result.shape
conv1d = nn.Conv1d(32, 32, bias=False, kernel_size=3, padding=1)
x = conv3x3_result.transpose(2, 4)   # (b, c, w, h, d)
x = x.transpose(2, 3)                # (b, c, h, w, d)
x = x.reshape(b, c, d * h * w)       # depth varies fastest in the flattened sequence
x = conv1d(x)
x = x.reshape(b, c, w, h, d)         # undo the flattening, mirroring the 5D example above
x = x.transpose(3, 4)
x = x.transpose(2, 3)
x = x.reshape(b, c, d, h, w)
This replaces the 3D convolution over 5D data with a 2D convolution plus a 1D convolution. One caveat: because the 1D conv runs over the flattened d*h*w sequence, a kernel of size 3 straddles the boundary between adjacent depth blocks and mixes values from neighboring spatial positions. In theory the scheme should work, but whether it actually matches a true 3D convolution still needs to be verified.
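One piece of that verification can be done exactly: a Conv3d with a 1x3x3 kernel never mixes depth slices, so with shared weights it must match the per-slice Conv2d loop above. A minimal sketch (small shapes chosen here for speed):

```python
import torch
from torch import nn

torch.manual_seed(0)
xs = torch.randn(2, 3, 4, 8, 8)  # (batch, channels, depth, height, width)

# A 1x3x3 3D kernel sees one depth slice at a time,
# so it should equal a 2D conv applied slice by slice.
conv3d = nn.Conv3d(3, 5, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)
conv2d = nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv2d.weight.copy_(conv3d.weight.squeeze(2))  # share the same weights

ref = conv3d(xs)
per_slice = torch.stack([conv2d(s) for s in xs.unbind(dim=2)], dim=2)
print(torch.allclose(ref, per_slice, atol=1e-6))  # True
```

The analogous check for the 1D part would compare a (3,1,1)-kernel Conv3d against a Conv1d applied at each spatial position; with the flattened-sequence trick above, that comparison would expose the block-boundary mixing.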