1.np.reshape,np.transpose和axis
在阅读YOLO V1代码过程中,出现了一段代码:
self.offset = np.transpose(np.reshape(np.array( #reshape之后再转置,变成7*7*2的三维数组
[np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell),
(self.boxes_per_cell, self.cell_size, self.cell_size)), (1, 2, 0))
其中的self.cell_size=7,self.boxes_per_cell=2,那么翻译一下,就是:
np.transpose(np.reshape(np.array([np.arange(7)] * 7 * 2),(2, 7, 7)), (1, 2, 0))
我们来逐一解读:
先从:np.array([np.arange(7)] * 7 * 2开始,这个简单,就是[array([0, 1, 2, 3, 4, 5, 6])*14:
就是:
[array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6]), array([0, 1, 2, 3, 4, 5, 6])]
共14个array([0, 1, 2, 3, 4, 5, 6])。
现在到np.reshape(np.array([np.arange(7)] * 7 * 2),(2, 7, 7)),就是将上述的14组array([0, 1, 2, 3, 4, 5, 6])转为shape(2,7,7)
这里我们插一句话:shape(2,7,7),就是说有三个维度,axis=0的维度为2,axis=1的维度为7,axis=2的维度为7。我们再拓展说明下:
什么是axis?
我们常说的数组的shape为(5,3),就是axis=0的维度为5,axis=1的维度为3,那么经常就有人理解为axis=0就是行的意思,axis=1就是列的意思,这么说,在更高维的数组中,经常会产生绕晕感。我们在二维数组中,这样定义是没错的,axis=0是按行延展,axis=1按列延展,如图1所示。
![cada28de2e231d9d815871e78a740247.png](https://i-blog.csdnimg.cn/blog_migrate/f852dcc7d05bdbbd5727b14b062a23a2.jpeg)
但是在更高维的数组中,比如我们经常说的图片,或者是输入到卷积神经网络中的图片,shape=(batch_size,height,width,channels),这里的batch_size是axis=0的维度,就是图片的张数,height是图片的高(axis=1),width是图片的宽(axis=2),channels是图片的通道数(axis=3),如此说明,我们看出,其实按照axis进行排序的话,axis里面的内容会越来越精细。
我们将np.reshape(np.array([np.arange(7)] * 7 * 2),(2, 7, 7))输入到python 的IDLE中,得到了图2
![5371705c832d7e02d3178f1fa321e79d.png](https://i-blog.csdnimg.cn/blog_migrate/8b411894eba6f804379fc980b9b47796.jpeg)
我们根据shape为(2,7,7)可以看出,这个array有两个大的数组(从最外层的[ ]开始分解),其中大的数组如下:
[[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6]]
这个大的数组的shape为(7,7),每行是[0, 1, 2, 3, 4, 5, 6]。所以随着axis增大,内容是原来越精细的。
最后一步,np.transpose(np.reshape(np.array([np.arange(7)] * 7 * 2),(2, 7, 7)), (1, 2, 0)),这个np.transpose是numpy中的一个转置函数,如果很多人和我一样,真的在脑子里尝试转置这个(2,7,7)数组(图2),然后将axis从(0,1,2)转到(1,2,0),估计很多人和我一样,脑子转不过来。
那么我们就换种方法,原来的shape是(2,7,7),按照(1,2,0)转置后的shape是(7,7,2),行呀,我们脑子里搭建一个shape=(7,7,2)的全零数组。然后嘞,我们简单点,比如图一中处坐标位置(1,3,4)的值为4,这个坐标位置(1,3,4)通过(1,2,0)的转置后的坐标位置为(3,4,1)的坐标就是4,我们不妨就将shape=(7,7,2)中坐标位置为(3,4,1)的0值替换成4。
不断迭代,这样就将图2中所有元素都映射到shape=(7,7,2)全零数组中了,
生成了我们需要的转置后的结果了(如图3),其实transpose就是通过坐标索引变换,将值从一个坐标索引到另一个坐标中罢了。
![6c6611915a3ceb8dbada6b16878cb4d9.png](https://i-blog.csdnimg.cn/blog_migrate/e44805ce9e1a111388cac28aacc46c55.png)