PixelShuffle与张量的维度重塑操作
从矩阵转置到高维张量的维度交换
在阅读论文代码时,对于PixelShuffle的实现产生了一些困惑。在网上的一些代码中,使用简单的reshape以及相应的transpose(对于numpy)或permute(对于pytorch)函数就可以实现pixelshuffle的操作,而对于高维张量的维度间的交换结果始终没能有一个理解方式,来确保维度交换后的数据是正确的、自己想要的数据,因此这里以numpy为例,从二维矩阵的转置出发,理解高维张量的维度交换。记录一篇博客,供自己学习与记忆
关于PixelShuffle
PixelShuffle是low-level任务中常用的上采样操作,关于此的详细介绍可以参考PixelShuffle.
用numpy实现的pixelshuffle过程如下所示:
import numpy as np
a = np.arange(36).reshape([4, 3, 3]) # array(4, 3, 3) 对应于(C,H,W)
b = a.reshape([2, 2, 3, 3]) # array(2, 2, 3, 3)
c = b.transpose([2, 0, 3, 1]) # array(3, 2, 3, 3)
d = c.reshape([6, 6]) # upsampled array(6, 6) 即(2*H, 2*W)
而为什么上述维度重塑操作就可以实现如PixelShuffle.中所解释的上采样后的效果呢,我们如何确定经过维度交换并形状重整后的张量数据符合我们的预期?
从Array的base、strides、address属性理解矩阵转置
- ndarray的base属性是指当前array由哪一个array变换得来,举个例子
a = np.arange(36).reshape((4, 3, 3))
print(f"base of a:{a.base}\n")
b = a.reshape((2, 2, 3, 3))
print(f"base of b:{b.base}\n")
c = b.transpose((2, 0, 3, 1))
print(f"base of c:{c.base}\n")
# output:
"""
base of a:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
base of b:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
base of c:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
"""
可以看到,a,b,c均是由最原始的一维数组变换而来。如果一个数组变量是经过多次变换得到的,那么base属性返回的是最初始的base数组。
2. ndarray的stride属性代表每一个维度取值时,在原始的连续内存空间上每隔多少个元素取一个值,这里,原始的连续内存空间即是指a.base代表的数组:
print(f"stride of a:{a.strides}\n")
print(f"stride of b:{b.strides}\n")
"""
outputs:
stride of a:(36, 12, 4) # integer, 4bytes one element
stride of b:(72, 36, 12, 4)
# 其实转换为真实元素个数的话,
# a:(9, 3, 1)
# b:(18, 9, 3, 1)
"""
对于a来说,形状是(4,3,3),strides[2]为1,也就是说a的第2维在原始空间上每隔1个元素取值,第1维每隔3个元素取一个值进行填充,第0维每隔9个元素取值进行填充(这里填充只是形象化表达,实质上这些reshape操作不会改变原始内存空间,属于浅拷贝);b则以此类推:
print(f"a:{a}\n")
print(f"address of a:{a.__array_interface__['data']}\n")
print(f"b:{b}\n")
print(f"address of b:{b.__array_interface__['data']}\n")
"""
outputs:
a:[[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 17]]
[[18 19 20]
[21 22 23]
[24 25 26]]
[[27 28 29]
[30 31 32]
[33 34 35]]]
address of a:(2994087337488, False)
b:[[[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 17]]]
[[[18 19 20]
[21 22 23]
[24 25 26]]
[[27 28 29]
[30 31 32]
[33 34 35]]]]
address of b:(2994087337488, False)
"""
可以看到地址并未改变。这里我的个人理解方法是,从最右侧维度开始填充,按照相应strides进行取值,取够当前维度的元素数接着按照-2维的strides以及元素数进行取值,以此类推
3. 转置:
从二维矩阵转置操作出发,转置操作实质上只是改变了shape和strides属性:
mat_a = np.arange(12).reshape(3, 4)
print(f"mat_a shape:{mat_a.shape}\n")
print(f"mat_a strides:{mat_a.strides}\n")
print(f"mat_a:{mat_a}\n")
mat_b = mat_a.transpose((1, 0))
print(f"mat_b shape:{mat_b.shape}\n")
print(f"mat_b strides:{mat_b.strides}\n")
print(f"mat_b:{mat_b}\n")
"""
outputs:
mat_a shape:(3, 4)
mat_a strides:(16, 4)
mat_a:[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
mat_b shape:(4, 3)
mat_b strides:(4, 16)
mat_b:[[ 0 4 8]
[ 1 5 9]
[ 2 6 10]
[ 3 7 11]]
"""
矩阵转置时,strides也跟随矩阵进行了相应的转置,这样,对于mat_b而言,对于第1维,每隔4个元素取一个值,取3次,获得[0, 4, 8],然后开始第0维的下一个元素,起始元素与0相差1个元素,然后与起始元素每隔4个元素取一个值取3次,得到[1,5,9]以此类推。最终得到我们熟悉的转置后的矩阵b
由以上的理解推广到高维张量数组以及PixelShuffle操作
对于高维数组的维度之间的置换操作而言,我们一定要记得,strides属性也进行了对应的置换,从strides属性出发,我们才能明晰置换维度后的数组内的元素究竟会变成什么样子,是否符合我们想要的结果:
a = np.arange(36).reshape((4, 3, 3))
print(f"stride of a:{a.strides}\n")
print(f"shape of a:{a.shape}\n")
print(f"base of a:{a.base}\n")
print(f"a:{a}\n")
print(f"address of a:{a.__array_interface__['data']}\n")
b = a.reshape((2, 2, 3, 3))
print(f"stride of b:{b.strides}\n")
print(f"shape of b:{b.shape}\n")
print(f"base of b:{b.base}\n")
print(f"b:{b}\n")
print(f"address of b:{b.__array_interface__['data']}\n")
c = b.transpose((2, 0, 3, 1))
print(f"stride of c:{c.strides}\n")
print(f"shape of c:{c.shape}\n")
print(f"base of c:{c.base}\n")
print(f"c:{c}\n")
print(f"address of c:{c.__array_interface__['data']}\n")
d = c.reshape((6, 6))
print(f"stride of d:{d.strides}\n")
print(f"shape of d:{d.shape}\n")
print(f"base of d:{d.base}\n")
print(f"d:{d}\n")
print(f"address of d:{d.__array_interface__['data']}\n")
"""
outputs:
stride of a:(36, 12, 4)
shape of a:(4, 3, 3)
base of a:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
a:[[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 17]]
[[18 19 20]
[21 22 23]
[24 25 26]]
[[27 28 29]
[30 31 32]
[33 34 35]]]
address of a:(2959935178752, False)
stride of b:(72, 36, 12, 4)
shape of b:(2, 2, 3, 3)
base of b:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
b:[[[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 17]]]
[[[18 19 20]
[21 22 23]
[24 25 26]]
[[27 28 29]
[30 31 32]
[33 34 35]]]]
address of b:(2959935178752, False)
stride of c:(12, 72, 4, 36)
shape of c:(3, 2, 3, 2)
base of c:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
c:[[[[ 0 9]
[ 1 10]
[ 2 11]]
[[18 27]
[19 28]
[20 29]]]
[[[ 3 12]
[ 4 13]
[ 5 14]]
[[21 30]
[22 31]
[23 32]]]
[[[ 6 15]
[ 7 16]
[ 8 17]]
[[24 33]
[25 34]
[26 35]]]]
address of c:(2959935178752, False)
stride of d:(24, 4)
shape of d:(6, 6)
base of d:[[[[ 0 9]
[ 1 10]
[ 2 11]]
[[18 27]
[19 28]
[20 29]]]
[[[ 3 12]
[ 4 13]
[ 5 14]]
[[21 30]
[22 31]
[23 32]]]
[[[ 6 15]
[ 7 16]
[ 8 17]]
[[24 33]
[25 34]
[26 35]]]]
d:[[ 0 9 1 10 2 11]
[18 27 19 28 20 29]
[ 3 12 4 13 5 14]
[21 30 22 31 23 32]
[ 6 15 7 16 8 17]
[24 33 25 34 26 35]]
address of d:(2959935177312, False)
"""
这里需要注意的是d的地址发生了变化,说明d是深拷贝得来的。原因可以参考pytorch中view()和reshape()的解析. 我的理解numpy对于维度置换后再进行reshape操作时深拷贝的原因应是和pytorch中相同的。由于c是经过维度置换的,不满足前面参考博文中所讲述的连续性条件,因此再对c进行reshape时,会开辟一块内存空间,对于c中的元素按照行向量化,进行连续存储,再在这块连续存储的区域上进行reshape操作。而此时d的base也随之发生变化,变为新的连续内存空间中存的数组。
由于d的shape是(6,6),也就是PixelShuffle操作后的上采样结果,那么从最右侧的维度出发计算strides,可以推出strides:([24, 4])(字节表示),([6, 1])(元素个数表示)。根据这个strides取值填充d数组,得到上述d的打印结果,可以看到是符合PixelShuffle的设计所期望的输出结果的。
不置换维度,直接reshape的情况
首先说明,这样做对于pixel shuffle而言肯定结果是错误的,只是运行一下代码加深对reshape、strides的理解
e = a.reshape((3, 2, 3, 2))
print(f"stride of e:{e.strides}\n")
print(f"shape of e:{e.shape}\n")
print(f"base of e:{e.base}\n")
print(f"e:{e}\n")
print(f"address of e:{e.__array_interface__['data']}\n")
f = e.reshape((6, 6))
print(f"stride of f:{f.strides}\n")
print(f"shape of f:{f.shape}\n")
print(f"base of f:{f.base}\n")
print(f"f:{f}\n")
print(f"address of f:{f.__array_interface__['data']}\n")
"""
outputs:
#注意博主分了两次运行,所以地址发生了变化,但是对于变量而言,a的地址和e、f是相同的
stride of e:(48, 24, 8, 4)
shape of e:(3, 2, 3, 2)
base of e:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
e:[[[[ 0 1]
[ 2 3]
[ 4 5]]
[[ 6 7]
[ 8 9]
[10 11]]]
[[[12 13]
[14 15]
[16 17]]
[[18 19]
[20 21]
[22 23]]]
[[[24 25]
[26 27]
[28 29]]
[[30 31]
[32 33]
[34 35]]]]
address of e:(1635184876144, False)
stride of f:(24, 4)
shape of f:(6, 6)
base of f:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35]
f:[[ 0 1 2 3 4 5]
[ 6 7 8 9 10 11]
[12 13 14 15 16 17]
[18 19 20 21 22 23]
[24 25 26 27 28 29]
[30 31 32 33 34 35]]
address of f:(1635184876144, False)
Process finished with exit code 0
"""
由以上结果可以看到,不进行维度置换,而只进行reshape操作时,始终是对原始的连续内存空间中存储的数组按照相应的strides进行取值的。这种情况下,strides可以这样推算: 最右侧维度(即-1维)总是每隔一个元素取值(满足连续性条件)(未进行过维度置换的情况),维度从右往左按照形状依次推出其他维度对应的stride。例如f的形状是(6,6),我们首先知道strides f:([xx, 4]),因为在只进行reshape的情况下,-1维的strides一定是4字节(对于int,即1个元素),而-1维有6个元素,那么对于第0维而言就是每隔6个元素取一个值即24字节,因此strides f:([24, 4])
总结
以上就是基于numpy的维度置换加reshape操作的理解。之后自己在遇到需要设计类似pixelshuffle的涉及维度置换等操作的算子时,可以参考这些步骤去推导维度置换加reshape后的结果。