np.expand_dims的作用是通过在指定位置插入新的轴来扩展数组形状,
函数格式如下:
np.expand_dims(array, axis)
当axis=0时,我们从最外面的括号开始看,axis=0是最高维度,axis=1就是从左至右第二个括号所包含的内容,以此类推。
import numpy as np
a = np.array([[1, 2, 3], [4, 5, 6]])
print(a.shape)
a.shape(2, 3)
(1)当axis = 0 时,我们从最高维扩展数组(增加它的维度),就给它在最外面再添个括号就可以啦,即yoz平面。所以,
b = np.expand_dims(a, axis=0)
b = [[[1, 2, 3], [4, 5, 6]]]
b.shape = (1, 2, 3)
shape(1, 2, 3)表示从左至右(从高维到低维:axis = 0 , 1, 2)所包含内容的情况,axis = 0(最外面的括号)只包含1个内容:[[1, 2, 3], [4, 5, 6]]; axis = 1(第二个括号)里包含2个内容:[1, 2, 3], [4, 5, 6]。axis = 3(第三个括号)里包含3个内容:1,2,3;或4,5,6
(2)当axis = 1 时,就是给第二个括号再加一个括号啦,即xoz平面
c = np.expand_dims(a, axis=1)
c = [[[1, 2, 3]], [[4, 5, 6]]]
c.shape = (2, 1, 3)
(3)当axis = 2 时,就是给最里面的数加一个括号,即xoy平面
d = np.expand_dims(a, axis=2)
d = [[[1], [2], [3]], [[4], [5], [6]]]
d.shape = (2, 3, 1)