bboxes = torch.cat([
torch.arange(
bs, requires_grad=False
).to(bboxes.device).repeat_interleave(self.num_objects).reshape(-1, 1),
bboxes.flatten(0, 1),
], dim=1)
# print(bboxes)
tensor([[ 0.0000, -0.6644, 0.8709, -0.7267, -2.2695],
[ 0.0000, 1.4619, -1.0581, -1.8898, 0.7876],
[ 0.0000, -1.6069, -1.3537, 0.5698, 0.6994],
[ 1.0000, 0.2495, -0.5027, -1.5294, 0.0376],
[ 1.0000, -0.8491, 0.4727, 0.4423, 0.5050],
[ 1.0000, 0.3200, -0.1111, -1.0900, -2.4595],
[ 2.0000, -0.7400, -0.2743, 1.6047, -0.5900],
[ 2.0000, 0.3749, -0.0864, 1.2539, -0.9245],
[ 2.0000, -0.1543, -1.1329, 1.7370, 0.7067],
[ 3.0000, 1.6104, -1.0426, 0.8122, -0.1253],
[ 3.0000, -0.6058, -0.5058, -1.2164, -0.7558],
[ 3.0000, -0.0525, -1.8111, -0.8256, -1.1813]])
12个坐标 4个Batch
一个图3个框
∴ 4个图 12个框 ==》 12个坐标