使用numpy.expand_dims可以将numpy.ndarray数据增加一个维度,数值为1,与np.squeeze相反,类似于pytorch中的torch.squeeze和torch.unsqueeze
numpy.expand_dims(a, axis)
a: numpy.ndarray数据
axis: 增加维度的轴,取值为[0, a.ndim]
【sample】
In [1]: import numpy as np
In [2]: dat = np.array([1,2,3])
In [3]: dat.ndim
Out[3]: 1
In [4]: d1 = np.expand_dims(dat, 0)
In [5]: d1.shape
Out[5]: (1,3)
In [6]: d2 = np.expand_dims(dat, 1)
In [7]: d2.shape
Out[7]: (3,1)