- 定义
numpy.expand_dims(a, axis)
该功能是改变数组维度,在axis参数指定的axis维度添加新的axis,从而为了未来的深度学习中选用适当的维度的tensor。与np.squeeze()函数互为逆操作。
参数:
a:array_like
axis: int 或 int的tuple
返回:
result:ndarray - 使用示例
这里举两个例子,对于一维数组和二维矩阵的维度扩展。
- 例子1
x = np.array([1,2,3])
print(x.shape)
y = np.expand_dims(x,axis=0)
print(y.shape)
输出结果:
(3,)
(1, 3)
- 例子2
x = np.array([[1,2,3],[4,5,6]])
print(x.shape)
y = np.expand_dims(x,axis=(1,0))# 在axis=0和axis=1两个维度上进行维度加1
print(y.shape)
输出结果:
(2, 3)
(1, 1, 2, 3)
- 题外话
其实刚开始用numpy的时候,对于高维数组比如上面示例中的shape为(1,1,2,3)很难想象,对于axis的索引方式难免混乱,在这里一并做出总结。
- 解释numpy中的高维数组
- 理解numpy中的axis
参考链接:
https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html