Sparse4Dv3 代码学习(Ⅲ)时序多帧推理

上一篇文章Sparse4Dv3 代码学习(Ⅱ)单帧推理-CSDN博客介绍了单帧,也就是序列的第一帧的推理过程,这篇文章主要介绍引入历史帧推理时的处理过程。

①InstanceBank,主要是把缓存的anchor投影到当前帧

        # ========= get instance info ============
        if (
            self.sampler.dn_metas is not None
            and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
        ):  # 第一帧不进入
            self.sampler.dn_metas = None
        (
            instance_feature,
            anchor,
            temp_instance_feature,
            temp_anchor,
            time_interval,
        ) = self.instance_bank.get(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )

第一个很不一样的是实例库,这个时候实例库缓存了历史帧的实例特征和anchor

    def get(self, batch_size, metas=None, dn_metas=None):
        instance_feature = torch.tile(
            self.instance_feature[None], (batch_size, 1, 1) # self.instance_feature: torch.Size([900, 256])
        )   # torch.Size([1, 900, 256])
        anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))  # torch.Size([1, 900, 11])

        if (
            self.cached_anchor is not None
            and batch_size == self.cached_anchor.shape[0]
        ):  # 第一帧时不进入
            history_time = self.metas["timestamp"]
            time_interval = metas["timestamp"] - history_time   # tensor([0.4999], device='cuda:0', dtype=torch.float64)
            time_interval = time_interval.to(dtype=instance_feature.dtype)
            self.mask = torch.abs(time_interval) <= self.max_time_interval  # 0.5<2

第一步是把anchor转换到当前帧:

            if self.anchor_handler is not None:
                T_temp2cur = self.cached_anchor.new_tensor(
                    np.stack(
                        [
                            x["T_global_inv"]
                            @ self.metas["img_metas"][i]["T_global"]
                            for i, x in enumerate(metas["img_metas"])
                        ]
                    )
                )   # torch.Size([1, 4, 4])
                self.cached_anchor = self.anchor_handler.anchor_projection(
                    self.cached_anchor,
                    [T_temp2cur],
                    time_intervals=[-time_interval],
                )[0]

里面的投影的细节,包括中心点(center)的投影和速度(vel)的转换:

    @staticmethod   # 这应该是帧间anchor的投影转换
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            vel = anchor[..., VX:]
            vel_dim = vel.shape[-1]
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )   # torch.Size([1, 1, 4, 4])

            center = anchor[..., [X, Y, Z]]
            if time_intervals is not None:  
                time_interval = time_intervals[i]   # tensor([-0.4999], device='cuda:0')
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=vel.dtype
                )
            else:
                time_interval = None
            if time_interval is not None:
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation
            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )
            size = anchor[..., [W, L, H]]
            yaw = torch.matmul(
                T_src2dst[..., :2, :2],
                anchor[..., [COS_YAW, SIN_YAW], None],
            ).squeeze(-1)
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
            ).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            # TODO: Fix bug
            # index = [X, Y, Z, W, L, H, COS_YAW, SIN_YAW] + [VX, VY, VZ][:vel_dim]
            # index = torch.tensor(index, device=dst_anchor.device)
            # index = torch.argsort(index)
            # dst_anchor = dst_anchor.index_select(dim=-1, index=index)
            dst_anchors.append(dst_anchor)
        return dst_anchors

②历史anchor的编码

        if temp_anchor is not None:
            temp_anchor_embed = self.anchor_encoder(temp_anchor)

 ③在refine里面会用到历史anchor的编码temp_anchor_embed:

现在只需要挑出300个anchor:

        N = self.num_anchor - self.num_temp_instances   # 900-600等于300
        confidence = confidence.max(dim=-1).values
        _, (selected_feature, selected_anchor) = topk(
            confidence, N, instance_feature, anchor
        )

根据选择出来的300个anchor合并到缓存的600个anchor一起(包括anchor、实例特征、实例ID):

        selected_feature = torch.cat(
            [self.cached_feature, selected_feature], dim=1
        )
        selected_anchor = torch.cat(
            [self.cached_anchor, selected_anchor], dim=1
        )
        instance_feature = torch.where(
            self.mask[:, None, None], selected_feature, instance_feature
        )   # 因为self.mask=True,所以实际上都选择的是selected_feature
        anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor)
        if self.instance_id is not None:
            self.instance_id = torch.where(
                self.mask[:, None],
                self.instance_id,
                self.instance_id.new_tensor(-1),
            )

④在gnn里面会用到时间信息

            elif op == "temp_gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,   # q
                    temp_instance_feature,  # k
                    temp_instance_feature,  # v
                    query_pos=anchor_embed,
                    key_pos=temp_anchor_embed,
                    attn_mask=attn_mask
                    if temp_instance_feature is None
                    else None,
                )
    def graph_model(
        self,
        index,
        query,
        key=None,
        value=None,
        query_pos=None,
        key_pos=None,
        **kwargs,
    ):
        if self.decouple_attn:  # 进入
            query = torch.cat([query, query_pos], dim=-1)   # torch.Size([1, 900, 512])
            if key is not None: # 第一帧的gnn_temp不进入  后续会进入
                key = torch.cat([key, key_pos], dim=-1)
            query_pos, key_pos = None, None
        if value is not None:   # temp_gnn不进入  gnn进入
            value = self.fc_before(value)   # torch.Size([1, 900, 512]) torch.Size([1, 600, 512])
        return self.fc_after(
            self.layers[index](
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                **kwargs,
            )
        )

④继续缓存600个:

    def cache(
        self,
        instance_feature,
        anchor,
        confidence,
        metas=None,
        feature_maps=None,
    ):
        if self.num_temp_instances <= 0:
            return
        instance_feature = instance_feature.detach()
        anchor = anchor.detach()
        confidence = confidence.detach()

        self.metas = metas
        confidence = confidence.max(dim=-1).values.sigmoid()
        if self.confidence is not None:
            confidence[:, : self.num_temp_instances] = torch.maximum(
                self.confidence * self.confidence_decay,    # self.confidence_decay=0.6
                confidence[:, : self.num_temp_instances],   # 选前600
            )
        self.temp_confidence = confidence

        (
            self.confidence,
            (self.cached_feature, self.cached_anchor),
        ) = topk(confidence, self.num_temp_instances, instance_feature, anchor)

 

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 稀疏贝叶斯学习Sparse Bayesian Learning)是一种机器学习方法,用于估计线性模型中的参数。该方法通过在参数的先验概率分布中引入稀疏性的假设,从而得到稀疏解。稀疏解可以帮助我们更好地理解数据,并提高模型的泛化能力。 稀疏贝叶斯学习代码实现可以按照以下步骤进行: 1. 加载所需的库和数据集:加载用于稀疏贝叶斯学习的库,如NumPy和SciPy。加载数据集,并将其分为训练集和测试集。 2. 定义稀疏模型:使用贝叶斯公式和朴素贝叶斯假设,定义稀疏模型的先验和似然函数。先验函数通常使用Laplace先验或高斯先验,并通过调整超参数来控制稀疏性。 3. 定义优化问题:将稀疏模型转化为一个优化问题,以最小化损失函数。常见的损失函数包括最大似然估计、最小二乘法等。 4. 确定超参数:通过交叉验证或贝叶斯优化等方法,确定超参数的最佳取值。超参数包括先验函数的超参数和优化问题的参数,如正则化参数、学习率等。 5. 优化模型:使用优化算法(如梯度下降、共轭梯度等)迭代地调整参数,以最小化损失函数。在每次迭代中,通过更新规则更新参数,并使用先验函数对参数进行修剪,以保持稀疏性。 6. 评估模型:使用训练好的模型对测试集进行预测,并计算预测结果的准确率或其他性能指标。如果模型性能不满足要求,可以回到步骤4,重新选择超参数。 稀疏贝叶斯学习代码实现不仅限于上述步骤,还取决于具体的实现框架和程序设计。有多种工具和软件包可以用于实现稀疏贝叶斯学习,如Scikit-learn、TensorFlow等。根据所选框架的不同,代码实现可能有所差异,但总的思路和方法是相似的。 ### 回答2: 稀疏贝叶斯学习Sparse Bayesian Learning)是一种用于构建稀疏模型的机器学习方法。其主要思想是通过贝叶斯统计推断来自适应地确定模型的参数。 Sparse Bayesian Learning的代码实现通常包含以下几个步骤: 1. 数据处理:首先,需要将所需要的数据进行预处理。根据实际问题的要求,通常会进行数据清洗、归一化或者特征选择等操作。 2. 参数初始化:然后,需要对模型的参数进行初始化。一般而言,可以采用随机初始化的方式来赋初值。 3. 贝叶斯推断:接下来,通过贝叶斯推断的方法,根据观测数据来更新模型的参数。具体而言,可以采用变分贝叶斯(Variational Bayes)或马尔可夫链蒙特卡洛(Markov chain Monte Carlo)等方法来进行推断。 4. 条件概率计算:随后,根据推断得到的后验分布,可以计算得到参数的条件概率分布,进而用于模型的测试或预测。 5. 模型选择:最后,需要通过模型选择的方法,如最大后验估计(MAP)或正则化方法等,对模型的结构进行优化和筛选,以达到稀疏模型的目的。 需要注意的是,Sparse Bayesian Learning的代码实现会涉及到概率计算、数值优化、矩阵运算等复杂的数学和算法操作。因此在实际编写代码时,需要使用适当的编程工具和数学库,并仔细考虑算法的复杂度和效率。此外,代码中还需要进行适当的验证,以保证模型的准确性和可靠性。 ### 回答3: sparse bayesian learning(稀疏贝叶斯学习)是一种机器学习算法,旨在通过最小化预测误差和对模型假设的复杂度进行特征选择和模型参数估计。 sparse bayesian learning 代码实现主要包括以下步骤: 1. 数据预处理:将输入数据集进行标准化处理,以确保不同特征具有相同的尺度。 2. 初始化模型参数:初始化模型参数,如稀疏先验超参数和噪声方差。 3. 迭代训练:采用变分贝叶斯方法进行模型参数和特征选择的迭代更新。 4. E步(Expectation Step):使用当前模型参数估计每个数据点的后验概率。 5. M步(Maximization Step):根据数据点的后验概率更新模型参数。 6. 收敛判断:计算当前模型参数的对数似然函数,若变化小于设定阈值,则认为算法已经收敛,停止迭代。 7. 特征选择:基于模型参数的后验概率,选择具有高概率的特征作为最终的特征子集。 8. 预测:使用更新后的模型参数进行新数据点的预测,通过计算后验概率或对数似然函数来判断分类或回归问题的性能。 总之,sparse bayesian learning代码实现的关键在于迭代更新模型参数和特征选择过程,通过极大似然估计和模型复杂度的惩罚项来实现稀疏性。此算法在处理高维数据时具有优势,能够自动选择相关特征,提高模型的泛化性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值