一、numpy.squeeze函数
- 作用:从指定数组中删除长度为1的维度
- 用法:
numpy.squeeze(a, axis=None) # 或者torch.squeeze(a, axis=None)
- a为指定数组
- axis不为None时,指定的维度必须是长度为1的单维度;为None时,删除所有长度为1的单维度
- 举例
import numpy as np
a = np.arange(2).reshape(2, 1, 1)
# array([[[0]],
# [[1]]])
np.squeeze(a).shape
# (2,)
np.squeeze(a, axis=1).shape
# (2, 1)
np.squeeze(a, axis=2).shape
# (2, 1)
二、torch.unsqueeze函数
和squeeze
作用相反:在指定数组中加入长度为1的维度。用法类似。
三、torch.expand函数
如下图: