看源代码的学习笔记
源代码
def _get_proposal_pairs(self, proposals):
    """Pair every proposal with every other proposal in the same image.

    For each image, all ordered (subject, object) combinations of its N
    proposals are enumerated in row-major order: flat pair k = i * N + j has
    subject i and object j. Self-pairs (i == j) are discarded. When
    ``self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP`` is set, pairs whose
    ``boxlist_iou`` value is not > 0 are discarded as well.

    Args:
        proposals: per-image proposal containers, each exposing ``bbox``,
            ``size``, ``mode`` and a ``'labels'`` field (presumably BoxList
            objects -- confirm against the caller).

    Returns:
        A list with one ``BoxPairList`` per image, built from the kept
        subject box concatenated with the object box along dim 1 (an
        (M, 8) tensor if ``bbox`` is (N, 4)), plus the image's ``size`` and
        ``mode``. It carries the fields ``"idx_pairs"`` (subject/object
        proposal indices) and ``"label_pairs"`` (subject/object labels).
    """
    proposal_pairs = []
    # The loop index i is not used in the body.
    for i, proposals_per_image in enumerate(proposals):
        box_subj = proposals_per_image.bbox
        box_obj = proposals_per_image.bbox
        # Notes on the tensor ops used below:
        # - unsqueeze(d) inserts a new size-1 dimension at position d.
        # - squeeze() is the inverse: it drops every size-1 dimension,
        #   i.e. visually removes one level of [] brackets.
        # - repeat(*sizes): each dimension of the result is the matching
        #   repeat factor multiplied by the input's size in that dimension.
        # - view() behaves like reshape; passing -1 lets PyTorch infer that
        #   dimension so you need not compute it by hand. Loosely: flatten
        #   to 1-D in row-major order, then regroup into the requested shape
        #   (view needs compatible memory; repeat returns a fresh copy, so
        #   it is safe here).
        # Worked computation: see notes ("biji") 2.
        # Assuming bbox is (N, 4): box_subj becomes (N, N, 4) with
        # box_subj[i, j] = bbox[i]; box_obj becomes (N, N, 4) with
        # box_obj[i, j] = bbox[j].
        box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
        box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
        # Flattened row i * N + j = [subject box | object box].
        proposal_box_pairs = torch.cat(
            (box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)
        # Index grids of shape (N, N, 1): idx_subj[i, j] = i and
        # idx_obj[i, j] = j. Built on the CPU, then moved to the boxes' device.
        idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(1, box_obj.shape[0], 1).to(
            proposals_per_image.bbox.device)
        idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(box_subj.shape[0], 1, 1).to(
            proposals_per_image.bbox.device)
        # (N * N, 2): row i * N + j = [i, j], same row order as proposal_box_pairs.
        proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)
        # Worked computation: see notes ("biji") 3.
        # Indexing the labels with the index grids gathers the subject and
        # object label of every pair.
        label_subj = proposals_per_image.get_field('labels')[idx_subj]
        label_obj = proposals_per_image.get_field('labels')[idx_obj]
        proposal_label_pairs = torch.cat(
            (label_subj.view(-1, 1), label_obj.view(-1, 1)), 1)
        # Flat positions of the pairs whose subject and object differ
        # (drops the self-pairs on the diagonal).
        keep_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero(as_tuple=False).view(-1)
        # Optionally also drop pairs whose boxes do not overlap.
        if self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP:
            # Presumably an (N, N) IoU matrix; flattened row-major, so entry
            # i * N + j lines up with pair (i, j).
            ious = boxlist_iou(proposals_per_image, proposals_per_image).view(-1)
            ious = ious[keep_idx]
            # Of the pairs already kept, retain only those with IoU > 0.
            keep_idx = keep_idx[(ious > 0).nonzero(as_tuple=False).view(-1)]
        proposal_idx_pairs = proposal_idx_pairs[keep_idx]
        proposal_box_pairs = proposal_box_pairs[keep_idx]
        proposal_label_pairs = proposal_label_pairs[keep_idx]
        proposal_pairs_per_image = BoxPairList(proposal_box_pairs, proposals_per_image.size, proposals_per_image.mode)
        proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)
        proposal_pairs_per_image.add_field("label_pairs", proposal_label_pairs)
        proposal_pairs.append(proposal_pairs_per_image)
    return proposal_pairs
举例解释
# Standalone walk-through of the box / index pairing steps of
# _get_proposal_pairs, using three hand-written boxes.
import torch

# Stand-in for proposals_per_image.bbox: a (3, 4) tensor, one row per box.
# The pairing code never interprets the four numbers, so their layout
# (e.g. x1, y1, x2, y2 corners in "xyxy" mode) does not matter here.
proposals_per_image_bbox = torch.tensor([[0, 0, 2, 2], [1, 1, 3, 3], [2, 2, 4, 4]])
# box_subj and box_obj start as the SAME tensor (no copy is made); the
# repeat() calls below produce new tensors, so nothing is shared afterwards.
box_subj = proposals_per_image_bbox
box_obj = proposals_per_image_bbox
# Expand both to (3, 3, 4): box_subj[i, j] = box i, box_obj[i, j] = box j.
box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
# Show box_subj before and after flattening it to (9, 4).
print("box_subj:")
print(box_subj)
print(box_subj.view(-1, 4))
# Concatenate into (9, 8): row i * 3 + j = [box i | box j].
proposal_box_pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)
# Index grids of shape (3, 3, 1): idx_subj[i, j] = i, idx_obj[i, j] = j.
idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(
    1, box_obj.shape[0], 1).to(proposals_per_image_bbox.device)
print("torch.arange(box_subj.shape[0]).view(-1, 1, 1)")
print(torch.arange(box_subj.shape[0]).view(-1, 1, 1))
print(idx_subj)
idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(
    box_subj.shape[0], 1, 1).to(proposals_per_image_bbox.device)
print("torch.arange(box_obj.shape[0]).view(1, -1, 1)")
print(torch.arange(box_obj.shape[0]).view(1, -1, 1))
print(idx_obj)
# (9, 2): row i * 3 + j = [i, j], same row order as proposal_box_pairs.
proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)
# Print the results.
print("\nOriginal Bounding Boxes:")
print(proposals_per_image_bbox)
print("\nExpanded and Concatenated Bounding Box Pairs:")
print(proposal_box_pairs)
print("\nIndex Pairs for Bounding Box Pairs:")
print(idx_subj)
print(proposal_idx_pairs)
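对应输出结果理解:box_subj 和 box_obj 经扩展后形状均为 (3, 3, 4),其中 box_subj[i, j] 是第 i 个框,box_obj[i, j] 是第 j 个框。展平并拼接后,proposal_box_pairs 形状为 (9, 8),第 i*3+j 行为 [第 i 个框 | 第 j 个框];proposal_idx_pairs 形状为 (9, 2),依次为 [0,0], [0,1], [0,2], [1,0], …, [2,2],与 proposal_box_pairs 的各行一一对应。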
再针对源代码举例
# Standalone walk-through of the label pairing and self-pair filtering
# steps of _get_proposal_pairs, for three proposals.
import torch

# Example data: stand-in for proposals_per_image.get_field('labels'),
# a (3,) tensor holding one class label per proposal.
proposals_per_image_labels = torch.tensor([1, 2, 3])
labels = proposals_per_image_labels
# Pair index grids of shape (3, 3, 1), written out by hand. They equal what
# the source builds with torch.arange(3).view(-1, 1, 1).repeat(1, 3, 1) and
# torch.arange(3).view(1, -1, 1).repeat(3, 1, 1):
# idx_subj[i, j] = i (subject) and idx_obj[i, j] = j (object).
idx_subj = torch.tensor([[[0], [0], [0]], [[1], [1], [1]], [[2], [2], [2]]])
idx_obj = torch.tensor([[[0], [1], [2]], [[0], [1], [2]], [[0], [1], [2]]])
# Indexing with the grids gathers the subject / object label of every pair;
# both results have shape (3, 3, 1).
label_subj = labels[idx_subj]
label_obj = labels[idx_obj]
# (9, 2): row i * 3 + j = [label of proposal i, label of proposal j].
proposal_label_pairs = torch.cat((label_subj.view(-1, 1), label_obj.view(-1, 1)), 1)
# (9, 2): row i * 3 + j = [i, j], in the same row order as the label pairs.
idx_subj_flat = idx_subj.view(-1)
idx_obj_flat = idx_obj.view(-1)
proposal_idx_pairs = torch.cat((idx_subj_flat.view(-1, 1), idx_obj_flat.view(-1, 1)), 1)
# Flat positions of the pairs to KEEP: those whose subject and object differ.
# The self-pairs [0, 0], [1, 1], [2, 2] (positions 0, 4, 8) are left out.
keep_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero(as_tuple=False).view(-1)
# Print the results.
print("Original Labels:")
print(labels)
print("\nIndex Pairs for Subject Labels:")
print(idx_subj)
print("\nIndex Pairs for Object Labels:")
print(idx_obj)
print("\nLabels for Subject Objects:")
print(label_subj)
print("\nLabels for Object Objects:")
print(label_obj)
print("\nConcatenated Label Pairs:")
print(proposal_label_pairs)
# NOTE: despite its label, this prints all 9 pairs *before* filtering; the
# pairs actually kept are proposal_idx_pairs[keep_idx].
print("\nFiltered Index Pairs:")
print(proposal_idx_pairs)
print("\nFiltered Index for Keeping Pairs:")
# nonzero(as_tuple=False) returns an (M, 1) column of positions; .view(-1)
# flattens it into the 1-D keep_idx printed next.
print((proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero(as_tuple=False))
print(keep_idx)
最后输出结果
Original Labels:
tensor([1, 2, 3])
Index Pairs for Subject Labels:
tensor([[[0],
[0],
[0]],
[[1],
[1],
[1]],
[[2],
[2],
[2]]])
Index Pairs for Object Labels:
tensor([[[0],
[1],
[2]],
[[0],
[1],
[2]],
[[0],
[1],
[2]]])
Labels for Subject Objects:
tensor([[[1],
[1],
[1]],
[[2],
[2],
[2]],
[[3],
[3],
[3]]])
Labels for Object Objects:
tensor([[[1],
[2],
[3]],
[[1],
[2],
[3]],
[[1],
[2],
[3]]])
Concatenated Label Pairs:
tensor([[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 1],
[3, 2],
[3, 3]])
Filtered Index Pairs:
tensor([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[1, 1],
[1, 2],
[2, 0],
[2, 1],
[2, 2]])
Filtered Index for Keeping Pairs:
tensor([[1],
[2],
[3],
[5],
[6],
[7]])
tensor([1, 2, 3, 5, 6, 7])