感觉好多教程都是举例子让自己理解,没有说清楚。本文将直观地解释stack堆叠方法。
变量,如代码所示:
import numpy as np
a = np.array([[1,2],[3,4],[5,6]])
b = np.array([[10,20],[30,40],[50,60]])
c = np.stack([a,b], axis = 0)
d = np.stack([a,b], axis = 1)
e = np.stack([a,b], axis = 2)
a和b的形状都是(3, 2),其维度的直观理解如下图所示(图中颜色框内的数字为索引序号):
stack堆叠起来,意味着原来的两维变成三维。“三维”也就是说堆叠的方式有三种:
-
0维上堆叠,在新变量的0维上区分a,b,即把a,b直接作为新变量0维的元素并列起来。
-
1维上堆叠,在新变量的1维上区分a,b。
-
2维上堆叠,在新变量的2维上区分a,b。
根据以上说明,很容易知道:n个变量在x维堆叠(axis=x),则堆叠产生的新变量的第x维的维数为n;或者说,新变量的shape值为原shape值在第x位插入数字n。