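Using None as an index into an ndarray or tensor inserts a new axis of length 1 at the position where the None appears. It has the same effect as torch.unsqueeze() in PyTorch or tf.expand_dims() in TensorFlow (and np.expand_dims() in NumPy, where np.newaxis is simply an alias for None).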
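To make "at the position where the None appears" concrete, here is a small NumPy sketch. It assumes only NumPy and uses the same 3×4 array as the session. It places None at each possible position and checks the result against np.expand_dims with the matching axis.

```python
import numpy as np

arr = np.arange(12).reshape(3, 4)

# np.newaxis is just an alias for None, so the two spellings are interchangeable
assert np.newaxis is None

# The new length-1 axis appears wherever None sits in the index
front = arr[None, :, :]   # shape (1, 3, 4)
middle = arr[:, None, :]  # shape (3, 1, 4)
back = arr[:, :, None]    # shape (3, 4, 1)

# np.expand_dims(arr, axis) puts the new axis at the same positions
for expanded, axis in [(front, 0), (middle, 1), (back, 2)]:
    assert np.array_equal(expanded, np.expand_dims(arr, axis))

print(front.shape, middle.shape, back.shape)
# (1, 3, 4) (3, 1, 4) (3, 4, 1)
```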
Examples (an IPython session; numpy and torch were imported as np and torch in earlier cells):
In [5]: t=torch.from_numpy(np.arange(12).reshape(3,4))
In [6]: t
Out[6]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
In [7]: t.dim()
Out[7]: 2
In [8]: t[:,None,:]
Out[8]:
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])
In [9]: arr=np.arange(12).reshape(3,4)
In [10]: arr
Out[10]:
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
In [11]: arr[:,None,:]
Out[11]:
array([[[ 0,  1,  2,  3]],

       [[ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11]]])
In [12]: arr[:,None,None,:]
Out[12]:
array([[[[ 0,  1,  2,  3]]],


       [[[ 4,  5,  6,  7]]],


       [[[ 8,  9, 10, 11]]]])
In [13]: arr[:,None,None,:].shape
Out[13]: (3, 1, 1, 4)
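Each None adds one axis, as In [12] and In [13] show. The NumPy sketch below, again assuming only NumPy, compares that double-None index with the other ways of reaching shape (3, 1, 1, 4). These are a tuple axis in np.expand_dims (supported since NumPy 1.18), two chained single-axis calls (the pattern you would need with torch.unsqueeze(), which takes a single int dim), and a plain reshape.

```python
import numpy as np

arr = np.arange(12).reshape(3, 4)

# Each None inserts its own length-1 axis
four_d = arr[:, None, None, :]
assert four_d.shape == (3, 1, 1, 4)

# np.expand_dims also accepts a tuple of axes (NumPy >= 1.18)
assert np.array_equal(four_d, np.expand_dims(arr, axis=(1, 2)))

# torch.unsqueeze() takes a single int dim, so two new axes mean two calls;
# the NumPy analogue is two nested single-axis expand_dims calls
chained = np.expand_dims(np.expand_dims(arr, 1), 1)
assert np.array_equal(four_d, chained)

# When the target shape is known, reshape gives the same array
assert np.array_equal(four_d, arr.reshape(3, 1, 1, 4))
```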
In [14]: t
Out[14]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
In [15]: t.unsqueeze(1)
Out[15]:
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])
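torch.unsqueeze() also accepts a negative dim. The PyTorch documentation maps it to dim + input.dim() + 1, so for the 2-D t above, t.unsqueeze(-2) behaves like t.unsqueeze(1). np.expand_dims follows the same rule. The sketch below uses NumPy only, so it runs without PyTorch installed. It relates negative axes to None indexing, including the Ellipsis form arr[..., None], which appends a trailing axis without spelling out every existing dimension.

```python
import numpy as np

arr = np.arange(12).reshape(3, 4)

# The result has arr.ndim + 1 == 3 axes, so a negative axis counts back from
# the end of that result: -1 is position 2, -2 is position 1, -3 is position 0
assert np.expand_dims(arr, -2).shape == (3, 1, 4)
assert np.array_equal(np.expand_dims(arr, -2), arr[:, None, :])

# Ellipsis followed by None is the indexing spelling of axis=-1:
# it appends a trailing axis whatever the input's ndim
assert np.array_equal(np.expand_dims(arr, -1), arr[..., None])
assert np.arange(5)[..., None].shape == (5, 1)
assert np.zeros((2, 3, 4))[..., None].shape == (2, 3, 4, 1)
```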
In [16]: t
Out[16]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
In [17]: t.unsqueeze(1)==t[:,None,:]
Out[17]:
tensor([[[True, True, True, True]],

        [[True, True, True, True]],

        [[True, True, True, True]]])
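The == in In [17] compares element by element, so it returns a boolean tensor rather than a single True/False. In PyTorch, torch.equal(a, b) returns one bool (True when both have the same size and elements). np.array_equal plays the same role in NumPy. The NumPy sketch below shows both forms of comparison. It also checks that None indexing returns a view that shares memory with the original array rather than a copy. The PyTorch documentation likewise states that the tensor returned by unsqueeze() shares its data with the input.

```python
import numpy as np

arr = np.arange(12).reshape(3, 4)
expanded = arr[:, None, :]

# == compares element by element and returns a boolean array of the same shape
elementwise = expanded == np.expand_dims(arr, 1)
print(elementwise.shape)  # (3, 1, 4)

# Reduce it to one bool with .all(), or ask np.array_equal directly
assert elementwise.all()
assert np.array_equal(expanded, np.expand_dims(arr, 1))

# None indexing is basic indexing, so the result is a view of arr, not a copy
assert np.shares_memory(arr, expanded)
expanded[0, 0, 0] = 100
assert arr[0, 0] == 100  # the write through the view is visible in arr
```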