import torch
import numpy as np
z = torch.tensor(np.random.rand(2,5,3))
print(z[None].shape)
# torch.Size([[1, 2, 5, 3]]) 增加一个维度到第0个维度,这里等同于unsqueeze(0)
print(z[:,None].shape)
# torch.Size([2, 1, 5, 3]) 等价于z[:,None,:] 第0个维度不变,增加一个维度到第1个维度,剩下其他维度不变
print(z[...,None].shape)
# torch.Size([2, 5, 3, 1]) 增加一个维度到最后1个维度,...代表前面所有的维度不变
print(z[:,:,None].shape)
# torch.Size([2, 5, 1, 3])
P_mask = z[:, None]==z[None]
# torch.Size([2, 2, 5, 3]) 维度不同处,expand
python切片的高级用法
于 2022-08-02 12:17:54 首次发布