为什么np.expand_dims会报错
在 NumPy 中使用 np.expand_dims
函数来扩展数组的维度时,你不能指定一个超出数组当前维度范围的轴(axis)。数组的维度是从 0 开始的,依次增加。对于一个形状为
(
m
,
n
,
p
)
(m, n, p)
(m,n,p) 的三维数组,有效的轴范围是 0、1、2(也可以是负数,但必须在有效范围内)。
在你给出的例子中:
arr =
[[9, 77, 65, 44],
[43, 86, 9, 70],
[20, 10, 53, 97]]
这是一个形状为 ( 3 , 4 ) (3, 4) (3,4) 的二维数组,因此有效的轴范围是 0 和 1。
当你尝试使用 axis=(1, 3, 8)
时,3 和 8 显然超出了这个有效范围,因此会导致错误。
如果你想在现有轴之间或之前/之后添加新的轴(维度),你必须确保新的轴索引不超过扩展后的维度数。例如,对于一个形状为 ( 3 , 4 ) (3, 4) (3,4) 的二维数组,你可以添加一个新的轴使其变为三维,四维等,但新的轴的索引必须在 0 到 新维度数-1 的范围内。
要修正这个问题,你应确保指定的轴在扩展后的维度范围内。例如,你可以使用 axis=(0, 2, 3)
来从二维扩展到五维(原始的两个维度加上三个新的维度)。
举例子说明
当你使用 np.expand_dims(arr, 1)
时,你是在第 1 轴(记住,索引是从 0 开始的)上添加一个新的维度。在这个情况下,原始数组 arr
的形状是 ( (3, 4) )。
- 第 0 轴有 3 个元素。
- 第 1 轴有 4 个元素。
当你在第 1 轴上添加一个新的维度时,新的维度会被插入到第 1 轴的位置,并将原有的第 1 轴(和所有后续的轴)向后移动。新的维度的大小是 1,因此新的数组形状变为 ( (3, 1, 4) )。
简单地说,np.expand_dims
在指定的轴位置“插入”了一个新的大小为 1 的维度,而不是在数组的末尾添加。这就是为什么结果是 ( (3, 1, 4) ) 而不是 ( (3, 4, 1) )。如果你想得到 ( (3, 4, 1) ),你应该使用 np.expand_dims(arr, 2)
。
再举一个例子
在这个例子中,你有一个形状为 ( (3, 4) ) 的二维数组。当你使用 axis=(1, 3, 4)
来扩展维度时,这里是怎么工作的:
-
第一次扩展(
axis=1
): 在第 1 轴(索引从 0 开始)上添加一个新维度。这会将数组形状从 ( (3, 4) ) 变为 ( (3, 1, 4) )。 -
第二次扩展(
axis=3
): 现在数组已经是三维的(( (3, 1, 4) ))。在第 3 轴上添加一个新的维度,这会将数组形状从 ( (3, 1, 4) ) 变为 ( (3, 1, 4, 1) )。 -
第三次扩展(
axis=4
): 现在数组是四维的(( (3, 1, 4, 1) ))。在第 4 轴上添加一个新的维度,这会将数组形状从 ( (3, 1, 4, 1) ) 变为 ( (3, 1, 4, 1, 1) )。
因此,最终的数组形状会是 ( (3, 1, 4, 1, 1) )。
需要注意的是,每次扩展维度后,有效的轴索引范围都会增加。这就是为什么你可以在一个初始为二维的数组上使用 axis=(1, 3, 4)
而不出错的原因:每次扩展都会增加一个新的维度,使得后续的轴索引变得有效。