写在前面的话
这个系列的文章只是出于我这个萌新在学习深度学习框架过程中遇到的各种难点,因此用自己能够理解的方式做了一个笔记记录下来,某些知识点解释只是方便我自己的记忆和理解,可能不是那么准确,如果有误,希望能够帮忙指正,谢谢!
直观速记
假设arrays的维度为(3,4,5),可以理解为3个4行5列的数组,那么经过`numpy.stack()的变化后,得到的维度为
import numpy
arr.shape # (3,4,5)
numpy.stack(arr, axis=0) # (3,4,5)
numpy.stack(arr, axis=1) # (4,3,5)
numpy.stack(arr, axis=2) # (4,5,3)
也就是说,axis=x,就是将原本shape中的第0个维度挪到
第x维度;
因此axis=1
可以用来对二维矩阵做转置。
具体细节
(以下内容懒得自己打了,引自博客)
下文引用部分来自我的理解与批注
numpy.stack()
stack英文之意即为堆叠,故该函数的作用就是实现输入数个数组不同方式的堆叠,返回堆叠后的1个数组。
因此虽然axis=0虽然看起来没有变化,其实是因为我们输入的arr数组其实是看作数个数组的原因。
参数 | 描述 |
---|---|
入口参数1 | arrays,用来作为堆叠的数个形状维度相等的数组。 |
入口参数2 | axis,即指定依照哪个维度进行堆叠,也就是指定哪种方式进行堆叠数组。 |
输出 | 堆叠后的1个数组 |
程序举例
首先定义4个2行3列的数组,这里为什么定义4个以及为什么所有的数组里的数据都不同,我想从下面的实验中你会找到答案。
import numpy
a=numpy.arange(1, 7).reshape((2, 3))
b=numpy.arange(7, 13).reshape((2, 3))
c=numpy.arange(13, 19).reshape((2, 3))
d=numpy.arange(19, 25).reshape((2, 3))
这四个用于堆叠的数组如下所示:
[[1 2 3]
[4 5 6]]
[[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 17 18]]
[[19 20 21]
[22 23 24]]
接下来进行不同方式也就是不同维度的堆叠实验。
axis=0
print(numpy.stack([a, b,c,d], axis=0))
print(numpy.stack([a, b,c,d], axis=0).shape)
输出结果:
[[[ 1 2 3]
[ 4 5 6]]
[[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 17 18]]
[[19 20 21]
[22 23 24]]]
(4, 2, 3)
形象理解:axis等于几就说明在哪个维度上进行堆叠。当axis=0的时候,意味着整体,也就是一个2行3列的数组。所以对于0维堆叠,相当于简单的物理罗列,比如这四个数组代表的是4张图像的数据,进行0维堆叠也就是把它们按顺序排放了起来,形成了一个(4,2,3)的3维数组。
axis=1
print(numpy.stack([a, b,c,d], axis=1))
print(numpy.stack([a, b,c,d], axis=1).shape)
输出结果:
[[[ 1 2 3]
[ 7 8 9]
[13 14 15]
[19 20 21]]
[[ 4 5 6]
[10 11 12]
[16 17 18]
[22 23 24]]]
(2, 4, 3)
形象理解:**axis等于几就说明在哪个维度上进行堆叠。**当axis=1的时候,意味着第一个维度,也就是数组的每一行。所以对于1维堆叠,4个2行3列的数组,各自拿出自己的第一行数据进行堆叠形成3维数组的第一“行”,各自拿出自己的第二行数据进行堆叠形成3维数组的第二“行”,从而形成了一个(2,4,3)的3维数组。比如这四个数组分别代表的是对同一张图像进行不同处理后的数据,进行1维堆叠可以将这些不同处理方式的数据有条理的堆叠形成一个数组,方便后续的统一处理。
axis=2
print(numpy.stack([a, b,c,d], axis=2))
print(numpy.stack([a, b,c,d], axis=2).shape)
输出结果:
[[[ 1 7 13 19]
[ 2 8 14 20]
[ 3 9 15 21]]
[[ 4 10 16 22]
[ 5 11 17 23]
[ 6 12 18 24]]]
(2, 3, 4)
形象理解:**axis等于几就说明在哪个维度上进行堆叠。**当axis=2的时候,意味着第二个维度,也就是数组的每一行中的更深一层的维度。对于本例的2维数组来说就到了单个的数据。**注意:千万不要理解成2维堆叠是对每一列拿出来进行堆叠!**试想如果进行列堆叠,跟1维堆叠又有什么区别呢?不就是个转置的事嘛~所以希望大家对于维度有个更深刻的认识,我觉得维度的概念就是逐层深入的。比如我们这里谈到的整体、行、单个数据。所以对于2维堆叠,4个2行3列的数组,各自拿出自己每个数据,在对应的位置,进行堆叠。从而形成了一个(2,3,4)的3维数组。假如我们的4个数组分别代表的是一个图像的R、G、B数据(这儿有点不恰当,因为其实该是3个数组,但是我觉得这个例子会更形象,所以大家就不要在意这点细节啦),那我们进行2维堆叠后,不就形成了一个图像的RGB三通道的数组了嘛。
与numpy.concatenate()函数的区别
concatenate()函数用于实现数组的拼接,stack()会增加数组的维度,而concatenate()不会增加数组的维度。
**举例说明:**比如我们有4个shape为(2,3)的数组,这4个数组代表着4个不同的图像数据,在使用stack(axis=0)函数后,shape变为了(4,2,3),可以看出增加了一个维度,而增加的这个维度表征着不同的图像;在使用concatenate(axis=0)函数后,其shape变为了(8,3),可以看出维度并没有变化,相当于我们把这四张图片,在行维度上拼接起来,即维度不变,而是对维度内的内容进行充实,所以原来4张2行3列的图像,合成为了一张8行3列的图像。