[Causal Learning] VC R-CNN (CVPR 2020) Code

The author builds on the maskrcnn-benchmark framework (the predecessor of Detectron2). Inspired by Bottom-Up and Top-Down Attention for Image Captioning and VQA, Mask R-CNN is used as the Bottom-Up backbone to provide image features for downstream tasks such as Image Captioning and VQA.

As the paper notes, the RPN is removed and the GT bboxes are fed in directly; the training loss is modified accordingly:
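(The original formula image is not reproduced here.) Judging from the loss computation in loss.py shown later, the modified loss is roughly

L = L_self + L_cxt

where L_self is the ordinary cross-entropy of each object predicting its own label (the self predictor), and L_cxt is the context predictor's cross-entropy, in which each object, after the dictionary-based intervention, predicts the labels of the other objects in the same image.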

At test time, the model instead acts as a feature extractor: the features output by the ROI_HEAD are taken as the VC Features.

The config file is e2e_mask_rcnn_R_101_FPN_1x.yaml. Compared with the Mask R-CNN defaults, the author changes BASE_LR from 0.02 to 0.005 and MAX_ITER from 90k to 240k, and trains from scratch. The main files involved are in ROI_BOX_HEAD: FPN2MLPFeatureExtractor and FPNPredictor. The former performs RoI Align followed by flatten + two fc+ReLU layers and outputs a 1024-d feature; the latter, FPNPredictor, does class prediction and box regression.
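For reference, the solver overrides would look roughly like this in the yaml (a sketch; BASE_LR and MAX_ITER are the standard maskrcnn-benchmark solver keys):

SOLVER:
  BASE_LR: 0.005    # Mask R-CNN default: 0.02
  MAX_ITER: 240000  # Mask R-CNN default: 90000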


In ROIBoxHead() in box_head.py, causal_predictor and feature_save_path are added.

In predictor(), the box_regression branch is removed, so only classification is performed. class_logits and the class_logits_causal_list are fed into loss_evaluator(), and at test time save_object_feature_gt_bu() is executed, as sketched below.
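A minimal control-flow sketch of the modified forward (assumed from the description above, not verbatim from the repository; in particular, the post_processor call that attaches the "features" field is an assumption):

    def forward(self, features, proposals, targets=None):
        # pooled RoI features: one 1024-d vector per GT box
        x = self.feature_extractor(features, proposals)
        # self predictor: classification only, the box branch is removed
        class_logits = self.predictor(x)
        # context predictor: one (N*N, num_classes) logit tensor per image
        class_logits_causal_list = self.causal_predictor(x, proposals)

        if not self.training:
            # test time: attach features to the boxes (assumed) and dump them to disk
            result = self.post_processor(x, proposals)
            self.save_object_feature_gt_bu(x, result, targets)
            return x, result, {}

        loss_classifier, loss_causal = self.loss_evaluator(
            [class_logits], class_logits_causal_list, proposals)
        return x, proposals, dict(loss_classifier=loss_classifier, loss_causal=loss_causal)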

CausalPredictor() is added in roi_box_predictors.py:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from maskrcnn_benchmark.modeling import registry


@registry.ROI_BOX_PREDICTOR.register("CausalPredictor")
class CausalPredictor(nn.Module):
    def __init__(self, cfg, in_channels):
        super(CausalPredictor, self).__init__()

        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.embedding_size = cfg.MODEL.ROI_BOX_HEAD.EMBEDDING
        representation_size = in_channels

        # context predictor: classifies concatenated [y_i; z_j] pairs
        self.causal_score = nn.Linear(2*representation_size, num_classes)
        # projections for the attention between object features and the dictionary
        self.Wy = nn.Linear(representation_size, self.embedding_size)
        self.Wz = nn.Linear(representation_size, self.embedding_size)

        nn.init.normal_(self.causal_score.weight, std=0.01)
        nn.init.normal_(self.Wy.weight, std=0.02)
        nn.init.normal_(self.Wz.weight, std=0.02)
        nn.init.constant_(self.Wy.bias, 0)
        nn.init.constant_(self.Wz.bias, 0)
        nn.init.constant_(self.causal_score.bias, 0)

        self.feature_size = representation_size
        # confounder dictionary and the class prior p(z), both precomputed .npy files
        self.dic = torch.tensor(np.load(cfg.DIC_FILE)[1:], dtype=torch.float)
        self.prior = torch.tensor(np.load(cfg.PRIOR_PROB), dtype=torch.float)

    def forward(self, x, proposals):
        device = x.get_device()
        dic_z = self.dic.to(device)
        prior = self.prior.to(device)

        # split the flat feature batch back into per-image chunks
        box_size_list = [proposal.bbox.size(0) for proposal in proposals]
        feature_split = x.split(box_size_list)
        # apply the dictionary-based intervention image by image
        xzs = [self.z_dic(feature_pre_obj, dic_z, prior) for feature_pre_obj in feature_split]

        causal_logits_list = [self.causal_score(xz) for xz in xzs]

        return causal_logits_list

    def z_dic(self, y, dic_z, prior):
        """
        Please note that we compute the intervention over the whole batch
        rather than for one object as in the main paper.
        """
        length = y.size(0)
        if length == 1:
            print('debug')  # leftover debug guard for single-box images
        # attention of each object feature over the K dictionary entries: (N, K)
        attention = torch.mm(self.Wy(y), self.Wz(dic_z).t()) / (self.embedding_size ** 0.5)
        attention = F.softmax(attention, 1)
        # soft-attended dictionary entries: (N, K, D)
        z_hat = attention.unsqueeze(2) * dic_z.unsqueeze(0)
        # prior-weighted sum over the K entries: (N, D)
        z = torch.matmul(prior.unsqueeze(0), z_hat).squeeze(1)
        # pair every y_i with every z_j: (N*N, 2D)
        xz = torch.cat((y.unsqueeze(1).repeat(1, length, 1),
                        z.unsqueeze(0).repeat(length, 1, 1)), 2).view(-1, 2*y.size(1))

        # detect if we encounter nan
        if torch.isnan(xz).sum():
            print(xz)
        return xz
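To make the tensor shapes in z_dic concrete, here is a standalone walk-through with hypothetical sizes (N objects, K dictionary entries, D-dim features, E-dim embeddings):

import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical sizes: N objects in an image, K dictionary entries,
# D-dim RoI features, E-dim attention embeddings
N, K, D, E = 5, 100, 1024, 256
y = torch.randn(N, D)        # RoI features of the N objects
dic_z = torch.randn(K, D)    # confounder dictionary
prior = torch.rand(K)
prior = prior / prior.sum()  # p(z): one probability per dictionary entry
Wy, Wz = nn.Linear(D, E), nn.Linear(D, E)

attention = F.softmax(Wy(y) @ Wz(dic_z).t() / E ** 0.5, dim=1)  # (N, K)
z_hat = attention.unsqueeze(2) * dic_z.unsqueeze(0)             # (N, K, D)
z = torch.matmul(prior.unsqueeze(0), z_hat).squeeze(1)          # (N, D)
xz = torch.cat((y.unsqueeze(1).repeat(1, N, 1),
                z.unsqueeze(0).repeat(N, 1, 1)), 2).view(-1, 2 * D)
print(xz.shape)  # torch.Size([25, 2048]): one row per (i, j) object pair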

In loss.py, the __call__ function of FastRCNNLossComputation() is modified:

    def __call__(self, class_logits, causal_logits_list, proposals):
        """
        Computes the losses for VC R-CNN.
        This requires that the subsample method has been called beforehand.

        Arguments:
            class_logits (list[Tensor])
            causal_logits_list (list[Tensor])
            proposals (list[BoxList])

        Returns:
            classification_loss (Tensor)
            causal_loss (Tensor)
        """

        class_logits = cat(class_logits, dim=0)
        device = class_logits.device

        labels = [proposal.get_field("labels").to(dtype=torch.int64) for proposal in proposals]
        labels_self = cat(labels, dim=0)

        # self predictor loss: each object predicts its own label
        classification_loss = F.cross_entropy(class_logits, labels_self)

        # context predictor loss: pair (i, j) predicts object j's label
        causal_loss = 0.
        for causal_logit, label in zip(causal_logits_list, labels):
            # row i holds the labels of all objects in the image
            mask_label = label.unsqueeze(0).repeat(label.size(0), 1)
            # zero out the diagonal so an object never predicts itself
            mask = 1 - torch.eye(mask_label.size(0)).to(device)
            loss_causal = F.cross_entropy(causal_logit, mask_label.view(-1), reduction='none')
            loss_causal = loss_causal * mask.view(-1)
            causal_loss += torch.mean(loss_causal)

        return classification_loss, causal_loss
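To see what the context targets and the mask do, a tiny hypothetical example with 3 boxes:

import torch

label = torch.tensor([2, 7, 7])                # hypothetical labels of 3 GT boxes
mask_label = label.unsqueeze(0).repeat(3, 1)   # every row is [2, 7, 7]
mask = 1 - torch.eye(3)                        # zeros on the diagonal

print(mask_label.view(-1))  # tensor([2, 7, 7, 2, 7, 7, 2, 7, 7])
print(mask.view(-1))        # tensor([0., 1., 1., 1., 0., 1., 1., 1., 0.])
# flattened index i*3+j matches the (i, j) pair order produced by z_dic:
# pair (i, j) is trained to predict object j's label, and i == j is masked out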

Also added to ROIBoxHead() in box_head.py is a function that saves the features at test time:

    def save_object_feature_gt_bu(self, x, result, targets):

        for i, image in enumerate(result):
            # per-image feature matrix: (num_gt_boxes, 1024)
            feature_pre_image = image.get_field("features").cpu().numpy()
            try:
                # sanity check: one feature vector per GT box
                assert image.get_field("num_box")[0] == feature_pre_image.shape[0]
                image_id = str(image.get_field("image_id")[0].cpu().numpy())
                path = os.path.join(self.feature_save_path, image_id) + '.npy'
                np.save(path, feature_pre_image)
            except Exception:
                # log the offending image instead of crashing the extraction run
                print(image)
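Downstream tasks then read these per-image .npy files and, as discussed below, combine them with the usual Up-Down features. A minimal sketch (the paths, image id, and the Up-Down file are all hypothetical):

import numpy as np

vc_feat = np.load('vc_features/123456.npy')         # (num_gt_boxes, 1024) VC Feature
ud_feat = np.load('updown_features/123456.npy')     # matching Up-Down feature (hypothetical file)
fused = np.concatenate([ud_feat, vc_feat], axis=1)  # per-box concatenation for the downstream model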

Overall, the author stripped out everything related to the bbox branch. The Mask R-CNN used here requires GT bboxes at test time, so in effect it only performs classification and carries no localization information; the VC Feature therefore cannot be used on its own and must be combined with the Up-Down feature.

This differs from the Up-Down paper, whose Faster R-CNN can also be used directly for object detection.

The test results in the paper bear this out: "Only VC" scores lower than "Origin".
