多级筛选:
比如结构是2*2*3,只想选第三维的最大的
tx[index, best_n, g_y_center, g_x_center]
index=[01],best_n=[0,1]
最后只取两个值,第一行,第1列,第二行,第2列的。
筛选第3维最大的值,下面的代码不对,解决方法:查询max源码
也可以把3维用view降到2维再计算就可以了。
import torch
anch_ious = torch.Tensor([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]])
print('b shape',anch_ious.shape)
b = torch.max(anch_ious, 2)
print(b[0])
print(b[1])
b = b[1].squeeze(1)
print(b)
print(anch_ious[list(range(anch_ious.size(0))),list(range(anch_ious.size(1))), b])
通过值筛选:
import torch
x = torch.linspace(1, 8, steps=8).view(4, 2)
#筛选第一维和第二维都>5.5的
print(x)
area=(x[:,0]>5.5)&(x[:,1]>5.5)
b=x[area]
# b= x[torch.where((x[:,0]&g