relation_head

看源代码的学习笔记

源代码

def _get_proposal_pairs(self, proposals):
        proposal_pairs = []
        for i, proposals_per_image in enumerate(proposals):
            box_subj = proposals_per_image.bbox
            box_obj = proposals_per_image.bbox
            
            #unsqueeze()的作用是用来增加给定tensor的维度的,指定增加1维
            #squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。
            #直接将repeat参数对应乘上原tensor的shape就得到最终的shape
            #  tensor.view()这个函数有点类似reshape的功能,不想手动的去计算其他的维度值,就可以使用view(-1)
            #  简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。
            #具体计算见biji  2
            box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
            box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
            proposal_box_pairs = torch.cat(
                (box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)

            idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(1, box_obj.shape[0], 1).to(
                proposals_per_image.bbox.device)
            idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(box_subj.shape[0], 1, 1).to(
                proposals_per_image.bbox.device)
            proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)   
            
            #具体计算见biji  3
            label_subj = proposals_per_image.get_field('labels')[idx_subj]
            label_obj = proposals_per_image.get_field('labels')[idx_obj]
            proposal_label_pairs = torch.cat(
                (label_subj.view(-1, 1), label_obj.view(-1, 1)), 1)

            keep_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero(as_tuple=False).view(-1)

            # if we filter non overlap bounding boxes
            if self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP:
                ious = boxlist_iou(proposals_per_image, proposals_per_image).view(-1)
                ious = ious[keep_idx]
                keep_idx = keep_idx[(ious > 0).nonzero(as_tuple=False).view(-1)]
            proposal_idx_pairs = proposal_idx_pairs[keep_idx]
            proposal_box_pairs = proposal_box_pairs[keep_idx]
            proposal_label_pairs = proposal_label_pairs[keep_idx]
            proposal_pairs_per_image = BoxPairList(proposal_box_pairs, proposals_per_image.size, proposals_per_image.mode)
            proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)
            proposal_pairs_per_image.add_field("label_pairs", proposal_label_pairs)

            proposal_pairs.append(proposal_pairs_per_image)

        return proposal_pairs

举例解释

# 假设 proposals_per_image.bbox 是一个形状为 (3, 4) 的张量
# 每个目标区域由四个值表示(左上角坐标x、y,宽度和高度)
proposals_per_image_bbox = torch.tensor([[0, 0, 2, 2], [1, 1, 3, 3], [2, 2, 4, 4]])

# 复制 proposals_per_image_bbox,得到 box_subj 和 box_obj
box_subj = proposals_per_image_bbox
box_obj = proposals_per_image_bbox

# 对 box_subj 和 box_obj 进行扩展
box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)

# 打印 box_subj 的值
print("box_subj:")
print(box_subj)
print(box_subj.view(-1, 4))

# 将扩展后的 box_subj 和 box_obj 连接成 proposal_box_pairs
proposal_box_pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)

# 生成目标对的索引
idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(1, box_obj.shape[0], 1).to(proposals_per_image_bbox.device)
print("torch.arange(box_subj.shape[0]).view(-1, 1, 1)")
print(torch.arange(box_subj.shape[0]).view(-1, 1, 1))
print(idx_subj)
idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(box_subj.shape[0], 1, 1).to(proposals_per_image_bbox.device)
print("torch.arange(box_obj.shape[0]).view(1, -1, 1)")
print(torch.arange(box_obj.shape[0]).view(1, -1, 1))
print(idx_obj)
proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)

# 打印结果
print("\nOriginal Bounding Boxes:")
print(proposals_per_image_bbox)

print("\nExpanded and Concatenated Bounding Box Pairs:")
print(proposal_box_pairs)

print("\nIndex Pairs for Bounding Box Pairs:")
print(idx_subj)
print(proposal_idx_pairs)

对应输出结果理解

再针对源代码举例

# 示例数据
proposals_per_image_labels = torch.tensor([1, 2, 3])

# 假设 proposals_per_image.get_field('labels') 返回形状为 (3,) 的张量,表示每个目标区域的类别标签
labels = proposals_per_image_labels

# 生成目标对的索引
idx_subj = torch.tensor([[[0], [0], [0]], [[1], [1], [1]], [[2], [2], [2]]])
idx_obj = torch.tensor([[[0], [1], [2]], [[0], [1], [2]],[[0], [1], [2]]])

# 获取主语目标和客体目标的类别标签
label_subj = labels[idx_subj]
label_obj = labels[idx_obj]

# 将类别标签连接成 proposal_label_pairs
proposal_label_pairs = torch.cat((label_subj.view(-1, 1), label_obj.view(-1, 1)), 1)

# 生成目标对的索引
idx_subj_flat = idx_subj.view(-1)
idx_obj_flat = idx_obj.view(-1)
proposal_idx_pairs = torch.cat((idx_subj_flat.view(-1, 1), idx_obj_flat.view(-1, 1)), 1)

# 筛选不满足条件的目标对索引
keep_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero(as_tuple=False).view(-1)

# 打印结果
print("Original Labels:")
print(labels)

print("\nIndex Pairs for Subject Labels:")
print(idx_subj)

print("\nIndex Pairs for Object Labels:")
print(idx_obj)

print("\nLabels for Subject Objects:")
print(label_subj)

print("\nLabels for Object Objects:")
print(label_obj)

print("\nConcatenated Label Pairs:")
print(proposal_label_pairs)

print("\nFiltered Index Pairs:")
print(proposal_idx_pairs)

print("\nFiltered Index for Keeping Pairs:")
print((proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero(as_tuple=False))
print(keep_idx)

最后输出结果

Original Labels:
tensor([1, 2, 3])

Index Pairs for Subject Labels:
tensor([[[0],
         [0],
         [0]],

        [[1],
         [1],
         [1]],

        [[2],
         [2],
         [2]]])

Index Pairs for Object Labels:
tensor([[[0],
         [1],
         [2]],

        [[0],
         [1],
         [2]],

        [[0],
         [1],
         [2]]])

Labels for Subject Objects:
tensor([[[1],
         [1],
         [1]],

        [[2],
         [2],
         [2]],

        [[3],
         [3],
         [3]]])

Labels for Object Objects:
tensor([[[1],
         [2],
         [3]],

        [[1],
         [2],
         [3]],

        [[1],
         [2],
         [3]]])

Concatenated Label Pairs:
tensor([[1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 1],
        [3, 2],
        [3, 3]])

Filtered Index Pairs:
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]])

Filtered Index for Keeping Pairs:
tensor([[1],
        [2],
        [3],
        [5],
        [6],
        [7]])
tensor([1, 2, 3, 5, 6, 7])

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值