作者:机器视觉全栈er
来源:cvtutorials.com
2.1.2 索引
筛选出符合某种条件的subtensor。
torch.where: 根据布尔变量的值选择tensor中的元素,用法如下:
torch.where(condition, x, y)
下面举个简单的例子:
>>> import torch
>>> cvtutorials = torch.randn(3, 4)
>>> threshold = torch.zeros(3, 4)
>>> cvtutorials
tensor([[-1.6981, 1.0443, 2.7922, -0.8736],
[-2.0208, -0.4815, -0.1488, -0.9714],
[ 1.1035, 0.4089, 0.6279, 2.4600]])
>>> torch.where(cvtutorials > 0, cvtutorials, threshold)
tensor(