图像输出中,以六分类为例子,输出结果是【5,1,3,512,11】
维度解释:
batchsize=5 ,512个格点,每个格点铺设3个框
数据处理目的:
筛选出置信度维度 最后一维【11】中的第5个数值,筛选出置信度最大的锚框
难点:
1、取得置信度最大锚框的索引
2、根据索引在原输出结果中取出
torch.manual_seed(1)
x=torch.rand(5,3,1,512,11)
'''从 11 的维度中取出 '''
_,location = torch.max(x[...,4],dim=-1)
'''x[...,3] shape torch.Size([5, 3, 1, 512]) '''
'''目标是 找出 512 最大的位置,并在原来的prediction中取出对应维度'''
print(x.shape)
print(location.shape)
torch.manual_seed(1)
y=torch.rand(5, 3, 1, 1, 11)
for q in range(location.shape[0]):
for w in range(location.shape[1]):
for e in range(location.shape[2]):
y[q,w,e,0,:]=x[q,w,e,location[q,w,e].item(),:]
print(y.shape)