import torch
x = torch.tensor([[True,False,False,False],[False,True,True,True],[False,True,True,True]]) #torch.Size([3, 4])
a,b=torch.where(x)
a: tensor([0, 1, 1, 1, 2, 2, 2]) #值为True的第一维索引
b: tensor([0, 1, 2, 3, 1, 2, 3]) #值为True的第二维索引
将x 看成是3行4列的矩阵,值为True的索引为第0行0列,1行1列,1行2列: (0,0),(1,1),(1,2),(1,3),(2,1),(2,2),(2,3)
torch.where函数使用
最新推荐文章于 2024-09-10 21:14:51 发布