将图神经网络应用于多目标跟踪的基本框架

1、MOT为什么要用图神经网络

深度学习擅长捕捉欧几里得数据(图像、文本、视频)的隐藏模式。但是,如果是从非欧几里的域生成数据,对象之间具有复杂关系和相互依赖关系的图形就需要GNN

目标检测数据关联是多目标跟踪(MOT)系统的关键组成部分,对于目标检测来说,有很多种方法进行解决,有时候会应用GCN, 但GNN一般还是用来解决数据关联问题,数据关联问题的核心难点在于:1)遮挡给物体帧与帧之间的数据关联带来不小的挑战;2)长时间下的数据关联也会遇到困难

此外,使用普通的数据关联方法难以兼顾轨迹和帧内检测之间的上下文信息,而且忽略了对象-对象关系,这种对象与对象之间的相互作用对于数据关联来说是有不小意义的

现阶段在使用 GNN 的MOT算法中,都需要对图进行构建,其构建方法都大同小异,通过构建一个图 G(V,E) ,来对数据关联问题进行建模。 其中节点 V 是检测和轨迹的特征,边 E 表示节点即对象之间的相互作用,如此就将当前的运行状态进行了可视化,而在构建图后,对节点、边缘的更新模块也是必不可少,实现了静态图的不断更新

本文将以将GNN求解的最小流算法应用于多目标跟踪-CSDN博客为例阐述其和图有关的代码框架

2、图的构建

2.1、获取图数据

    def get_from_frame_and_seq(self, seq_name, start_frame, max_frame_dist, end_frame=None, ensure_end_is_in=False,
                               return_full_object=False, inference_mode=False):
        """
        __getitem__ 方法的底层方法。从给定的序列名称加载一个图对象,从 'start_frame' 开始。
        Args:
            seq_name: 指示要从中获取图的场景的字符串
            start_frame: 指示图应该从哪一帧开始的整数帧数
            end_frame: 指示图应该结束的整数帧数(可选)
            ensure_end_is_in: bool,指示是否需要 end_frame 在图中
            return_full_object: bool,指示是否需要整个 MOTGraph 对象
            还是只需要其 Graph 对象(图网络的输入)
        Returns:
            mot_graph: 输出 MOTGraph 对象或 Graph 对象,
            取决于 return_full_object == True 还是 False

        """
        seq_det_df = self.seq_det_dfs[seq_name]
        seq_info_dict = self.seq_info_dicts[seq_name]
        seq_step_size = self.seq_info_dicts[seq_name]['step_size']

        # 如果进行数据增强,则随机更改处理场景的 fps 率
        if self.mode == 'train' and self.augment and seq_step_size > 1:
            if np.random.rand() < self.dataset_params['p_change_fps_step']:
                seq_step_size = np.round(seq_step_size * (0.5 + np.random.rand())).astype(int)

        mot_graph = MOTGraph(dataset_params=self.dataset_params,
                             seq_info_dict=seq_info_dict,
                             seq_det_df=seq_det_df,
                             step_size=seq_step_size,
                             start_frame=start_frame,
                             end_frame=end_frame,
                             ensure_end_is_in=ensure_end_is_in,
                             max_frame_dist=max_frame_dist,
                             cnn_model=self.cnn_model,
                             inference_mode=inference_mode)

        if self.mode == 'train' and self.augment:
            mot_graph.augment()

        # 构造图网络的输入
        mot_graph.construct_graph_object()
        if self.mode in ('train', 'val'):
            mot_graph.assign_edge_labels()

        if return_full_object:
            return mot_graph
        else:
            return mot_graph.graph_obj

MOTGraph

class MOTGraph(object):
    """
    这是用来从检测(可能还包括GT)文件中创建 MOT 图的主要类。其主要属性是 'graph_obj',
    它是 'Graph' 类的一个实例,用作跟踪模型的输入。
    此外,每个 'MOTGraph' 还具有几个额外的属性,提供有关从中构建图的帧子集中的检测的进一步信息。
    """
    def __init__(self, seq_det_df=None, start_frame=None, end_frame=None, ensure_end_is_in=False, step_size=None,
                 seq_info_dict=None, dataset_params=None, inference_mode=False, cnn_model=None, max_frame_dist=None):
        """
        初始化方法。
        Args:
            seq_det_df: DataFrame,包含检测信息。
            start_frame: int,开始帧。
            end_frame: int,结束帧。
            ensure_end_is_in: bool,指示结束帧是否需要在图中。
            step_size: int,帧的步长。
            seq_info_dict: dict,序列信息字典。
            dataset_params: dict,数据集参数。
            inference_mode: bool,推断模式标志。
            cnn_model: CNN 模型,用于图像特征提取。
            max_frame_dist: int,最大帧距离。
        """
        self.dataset_params = dataset_params
        self.step_size = step_size
        self.seq_info_dict = seq_info_dict
        self.inference_mode = inference_mode
        self.max_frame_dist = max_frame_dist

        self.cnn_model = cnn_model

        if seq_det_df is not None:
            self.graph_df, self.frames = self._construct_graph_df(seq_det_df=seq_det_df.copy(),
                                                                      start_frame=start_frame,
                                                                          end_frame=end_frame,
                                                                  ensure_end_is_in=ensure_end_is_in)

2.2、构造图网络的输入

def construct_graph_object(self):
    """
    构建整个图对象以作为 MPN 的输入,并将其存储在 self.graph_obj 中
    """
    # 加载外观数据
    reid_embeddings, node_feats = self._load_appearance_data()  # 加载外观数据:检测信息和嵌入特征

    # 确定图连接性(即边)并计算边特征
    edge_ixs, edge_feats_dict = self._get_edge_ixs(reid_embeddings)  # 获取边的索引和特征
    if edge_feats_dict is None:
        edge_feats_dict = compute_edge_feats_dict(edge_ixs=edge_ixs, det_df=self.graph_df,
                                                  fps=self.seq_info_dict['fps'],
                                                  use_cuda=self.inference_mode)  # 计算边特征
    edge_feats = [edge_feats_dict[feat_names] for feat_names in self.dataset_params['edge_feats_to_use'] if
                  feat_names in edge_feats_dict]  # 提取边特征
    edge_feats = torch.stack(edge_feats).T

    # 计算嵌入距离。成对距离计算可能会导致内存不足的错误,因此我们进行分批处理
    emb_dists = []
    for i in range(0, edge_ixs.shape[1], 50000):
        emb_dists.append(F.pairwise_distance(reid_embeddings[edge_ixs[0][i:i + 50000]],
                                             reid_embeddings[edge_ixs[1][i:i + 50000]]).view(-1, 1))  # 计算嵌入距离
    emb_dists = torch.cat(emb_dists, dim=0)

    # 如果需要,将嵌入距离添加到边特征中
    if 'emb_dist' in self.dataset_params['edge_feats_to_use']:
        edge_feats = torch.cat((edge_feats, emb_dists), dim=1)

    # 构建图对象,大多用来更改数据类型
    self.graph_obj = Graph(x=node_feats,
                           edge_attr=torch.cat((edge_feats, edge_feats), dim=0),
                           edge_index=torch.cat((edge_ixs, torch.stack((edge_ixs[1], edge_ixs[0]))), dim=1))

    if self.inference_mode:
        self.graph_obj.reid_emb_dists = torch.cat((emb_dists, emb_dists))  # 存储嵌入距离

    self.graph_obj.to(torch.device("cuda" if torch.cuda.is_available() and self.inference_mode else "cpu")))  # 将图对象移动到GPU或CPU上

外观特征提取

    def _load_appearance_data(self):
        """
        加载节点特征和重新识别的嵌入。
        Returns:
            包含(重新识别嵌入,节点特征)的元组,都是形状为(num_nodes,embed_dim)的 torch.tensors
        """
        if self.inference_mode and not self.dataset_params['precomputed_embeddings']:
            assert self.cnn_model is not None
            print("使用 CNN 进行外观特征提取")
            _, node_feats, reid_embeds = load_embeddings_from_imgs(det_df=self.graph_df,
                                                                   dataset_params=self.dataset_params,
                                                                   seq_info_dict=self.seq_info_dict,
                                                                   cnn_model=self.cnn_model,
                                                                   return_imgs=False,
                                                                   use_cuda=self.inference_mode)

        else:
            reid_embeds = load_precomputed_embeddings(det_df=self.graph_df,
                                                      seq_info_dict=self.seq_info_dict,
                                                      embeddings_dir=self.dataset_params['reid_embeddings_dir'],
                                                      use_cuda=self.inference_mode)
            if self.dataset_params['reid_embeddings_dir'] == self.dataset_params['node_embeddings_dir']:
                node_feats = reid_embeds.clone()

            else:
                node_feats = load_precomputed_embeddings(det_df=self.graph_df,
                                                         seq_info_dict=self.seq_info_dict,
                                                         embeddings_dir=self.dataset_params['node_embeddings_dir'],
                                                         use_cuda=self.inference_mode)

        return reid_embeds, node_feats

获取边的特征数据

    def _get_edge_ixs(self, reid_embeddings):
        """
        通过获取具有有效时间连接的节点对构造图边(不在同一帧内,时间上不太远),并根据重新识别嵌入可能获取 KNN。
        Args:
            reid_embeddings: 形状为(num_nodes,reid_embeds_dim)的 torch.tensor
        Returns:
            形状为(2,num_edges)的 torch.tensor
        """
        edge_ixs = get_time_valid_conn_ixs(frame_num=torch.from_numpy(self.graph_df.frame.values),
                                           max_frame_dist=self.max_frame_dist,
                                           use_cuda=self.inference_mode and self.graph_df['frame_path'].iloc[0].find(
                                               'MOT17-03') == -1)

        edge_feats_dict = None
        if 'max_feet_vel' in self.dataset_params and self.dataset_params[
            'max_feet_vel'] is not None:  # 新参数。我们根据脚速度进行图修剪
            # print("VELOCITY PRUNING")
            edge_feats_dict = compute_edge_feats_dict(edge_ixs=edge_ixs, det_df=self.graph_df,
                                                      fps=self.seq_info_dict['fps'],
                                                      use_cuda=self.inference_mode)

            feet_vel = torch.sqrt(edge_feats_dict['norm_feet_x_dists'] ** 2 + edge_feats_dict['norm_feet_y_dists'] ** 2)
            vel_mask = feet_vel < self.dataset_params['max_feet_vel']
            edge_ixs = edge_ixs.T[vel_mask].T
            for feat_name, feat_vals in edge_feats_dict.items():
                edge_feats_dict[feat_name] = feat_vals[vel_mask]

        # 在推理期间,不应该在此处执行 top k nns,因为它是针对序列块单独计算的
        if not self.inference_mode and self.dataset_params['top_k_nns'] is not None:
            reid_pwise_dist = F.pairwise_distance(reid_embeddings[edge_ixs[0]], reid_embeddings[edge_ixs[1]])
            k_nns_mask = get_knn_mask(pwise_dist=reid_pwise_dist,
                                      edge_ixs=edge_ixs,
                                      num_nodes=self.graph_df.shape[0],
                                      top_k_nns=self.dataset_params['top_k_nns'],
                                      reciprocal_k_nns=self.dataset_params['reciprocal_k_nns'],
                                      symmetric_edges=False,
                                      use_cuda=self.inference_mode)
            edge_ixs = edge_ixs.T[k_nns_mask].T
            if edge_feats_dict is not None:
                for feat_name, feat_vals in edge_feats_dict.items():
                    edge_feats_dict[feat_name] = feat_vals[k_nns_mask]

        return edge_ixs, edge_feats_dict

边赋予标签

def assign_edge_labels(self):
    """
    为 self.graph_obj 的边赋予标签(形状为 (num_edges,) 的张量),标签根据网络流 MOT 公式定义
    """
    ids = torch.as_tensor(self.graph_df.id.values, device=self.graph_obj.edge_index.device)
    per_edge_ids = torch.stack([ids[self.graph_obj.edge_index[0]], ids[self.graph_obj.edge_index[1]]])

    # 找出具有相同 id 的边
    same_id = (per_edge_ids[0] == per_edge_ids[1]) & (per_edge_ids[0] != -1)
    same_ids_ixs = torch.where(same_id)
    same_id_edges = self.graph_obj.edge_index.T[same_id].T

    # 计算时间距离
    time_dists = torch.abs(same_id_edges[0] - same_id_edges[1])

    # 对于每个节点,获取未来(或过去)具有相同 id 的节点中最接近的节点的索引
    future_mask = same_id_edges[0] < same_id_edges[1]
    active_fut_edges = scatter_min(time_dists[future_mask], same_id_edges[0][future_mask], dim=0,
                                   dim_size=self.graph_obj.num_nodes)[1]
    original_node_ixs = torch.cat((same_id_edges[1][future_mask], torch.as_tensor([-1], device=same_id.device)))
    active_fut_edges = original_node_ixs[active_fut_edges]
    fut_edge_is_active = active_fut_edges[same_id_edges[0]] == same_id_edges[1]

    # 对于过去边也是类似的
    past_mask = same_id_edges[0] > same_id_edges[1]
    active_past_edges = scatter_min(time_dists[past_mask], same_id_edges[0][past_mask], dim=0,
                                    dim_size=self.graph_obj.num_nodes)[1]
    original_node_ixs = torch.cat((same_id_edges[1][past_mask], torch.as_tensor([-1], device=same_id.device)))
    active_past_edges = original_node_ixs[active_past_edges]
    past_edge_is_active = active_past_edges[same_id_edges[0]] == same_id_edges[1]

    # 恢复原始 edge_index 张量中活动边的索引
    active_edge_ixs = same_ids_ixs[0][past_edge_is_active | fut_edge_is_active]
    # 为活动边分配标签
    self.graph_obj.edge_labels = torch.zeros_like(same_id, dtype=torch.float)
    self.graph_obj.edge_labels[active_edge_ixs] = 1

2.3、图网络更新

对节点和边创建MLP

class MLPGraphIndependent(nn.Module):
    """
    用于在神经消息传递之前(或之后)对特征进行编码(或分类)的类。
    它由两个MLP组成,一个用于节点,一个用于边,它们分别独立应用于节点和边特征。
    该类基于:https://github.com/deepmind/graph_nets 的tensorflow实现。
    """
    def __init__(self, edge_in_dim=None, node_in_dim=None, edge_out_dim=None, node_out_dim=None,
                 node_fc_dims=None, edge_fc_dims=None, dropout_p=None, use_batchnorm=None):
        super(MLPGraphIndependent, self).__init__()

        # 如果提供了节点输入维度,则创建节点MLP
        if node_in_dim is not None:
            self.node_mlp = MLP(input_dim=node_in_dim, fc_dims=list(node_fc_dims) + [node_out_dim],
                                dropout_p=dropout_p, use_batchnorm=use_batchnorm)
        else:
            self.node_mlp = None

        # 如果提供了边输入维度,则创建边MLP
        if edge_in_dim is not None:
            self.edge_mlp = MLP(input_dim=edge_in_dim, fc_dims=list(edge_fc_dims) + [edge_out_dim],
                                dropout_p=dropout_p, use_batchnorm=use_batchnorm)
        else:
            self.edge_mlp = None

    def forward(self, edge_feats=None, nodes_feats=None):

        # 如果存在节点MLP,则将节点特征传递给节点MLP
        if self.node_mlp is not None:
            out_node_feats = self.node_mlp(nodes_feats)
        else:
            out_node_feats = nodes_feats

        # 如果存在边MLP,则将边特征传递给边MLP
        if self.edge_mlp is not None:
            out_edge_feats = self.edge_mlp(edge_feats)
        else:
            out_edge_feats = edge_feats

        return out_edge_feats, out_node_feats

节点和边更新模型

    def _build_core_MPNet(self, model_params, encoder_feats_dict):
        """
        构建消息传递网络的核心部分:节点更新和边更新模型。
        Args:
            model_params: 包含所有模型超参数的字典
            encoder_feats_dict: 包含初始节点/边编码器的超参数的字典
        """

        # 定义用于从邻接边汇聚节点消息的聚合算子
        node_agg_fn = model_params['node_agg_fn']
        assert node_agg_fn.lower() in ('mean', 'max', 'sum'), "node_agg_fn 只能为'max'、'mean'或'sum'。"

        if node_agg_fn == 'mean':
            node_agg_fn = lambda out, row, x_size: scatter_mean(out, row, dim=0, dim_size=x_size)

        elif node_agg_fn == 'max':
            node_agg_fn = lambda out, row, x_size: scatter_max(out, row, dim=0, dim_size=x_size)[0]

        elif node_agg_fn == 'sum':
            node_agg_fn = lambda out, row, x_size: scatter_add(out, row, dim=0, dim_size=x_size)

        # 定义MPN中使用的所有MLP
        # 对于节点和边,初始编码特征(即self.encoder的输出)可以在每个消息传递步骤之后重新连接或不重新连接。这影响MLP的输入维度
        self.reattach_initial_nodes = model_params['reattach_initial_nodes']
        self.reattach_initial_edges = model_params['reattach_initial_edges']

        edge_factor = 2 if self.reattach_initial_edges else 1
        node_factor = 2 if self.reattach_initial_nodes else 1

        edge_model_in_dim = node_factor * 2 * encoder_feats_dict['node_out_dim'] + edge_factor * encoder_feats_dict[
            'edge_out_dim']
        node_model_in_dim = node_factor * encoder_feats_dict['node_out_dim'] + encoder_feats_dict['edge_out_dim']

        # 定义MPN中使用的所有MLP
        edge_model_feats_dict = model_params['edge_model_feats_dict']
        node_model_feats_dict = model_params['node_model_feats_dict']

        edge_mlp = MLP(input_dim=edge_model_in_dim,
                       fc_dims=edge_model_feats_dict['fc_dims'],
                       dropout_p=edge_model_feats_dict['dropout_p'],
                       use_batchnorm=edge_model_feats_dict['use_batchnorm'])

        flow_in_mlp = MLP(input_dim=node_model_in_dim,
                          fc_dims=node_model_feats_dict['fc_dims'],
                          dropout_p=node_model_feats_dict['dropout_p'],
                          use_batchnorm=node_model_feats_dict['use_batchnorm'])

        flow_out_mlp = MLP(input_dim=node_model_in_dim,
                           fc_dims=node_model_feats_dict['fc_dims'],
                           dropout_p=node_model_feats_dict['dropout_p'],
                           use_batchnorm=node_model_feats_dict['use_batchnorm'])

        node_mlp = nn.Sequential(*[nn.Linear(2 * encoder_feats_dict['node_out_dim'], encoder_feats_dict['node_out_dim']),
                                   nn.ReLU(inplace=True)])

        # 定义MPN中使用的所有MLP
        return MetaLayer(edge_model=EdgeModel(edge_mlp=edge_mlp),
                         node_model=TimeAwareNodeModel(flow_in_mlp=flow_in_mlp,
                                                       flow_out_mlp=flow_out_mlp,
                                                       node_mlp=node_mlp,
                                                       node_agg_fn=node_agg_fn))

消息传递网络的核心,用于节点和边特征的更新

class MetaLayer(torch.nn.Module):
    """
    核心消息传递网络类。从torch_geometric中提取,进行了轻微修改。
    (https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html)
    """
    def __init__(self, edge_model=None, node_model=None):
        """
        Args:
            edge_model: 可调用的边更新模型
            node_model: 可调用的节点更新模型
        """
        super(MetaLayer, self).__init__()

        self.edge_model = edge_model
        self.node_model = node_model
        self.reset_parameters()

    def reset_parameters(self):
        # 重置模型参数
        for item in [self.node_model, self.edge_model]:
            if hasattr(item, 'reset_parameters'):
                item.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        """
        对单个节点和边特征向量进行更新。
        Args:
            x: 节点特征矩阵
            edge_index: 形状为 [2, M] 的张量,其中 M 是边数,表示图邻接矩阵中的非零条目(即边)
            edge_attr: 边特征矩阵(按照 edge_index 排序)
        Returns: 更新后的节点和边特征矩阵
        """
        row, col = edge_index

        # 边更新
        if self.edge_model is not None:
            edge_attr = self.edge_model(x[row], x[col], edge_attr)

        # 节点更新
        if self.node_model is not None:
            x = self.node_model(x, edge_index, edge_attr)

        return x, edge_attr

    def __repr__(self):
        return '{}(edge_model={}, node_model={})'.format(self.__class__.__name__, self.edge_model, self.node_model)

边更新模型

class EdgeModel(nn.Module):
    """
    用于在神经消息传递过程中执行边更新的类。
    """
    def __init__(self, edge_mlp):
        super(EdgeModel, self).__init__()
        self.edge_mlp = edge_mlp

    def forward(self, source, target, edge_attr):
        # 将源节点特征、目标节点特征和边特征拼接起来
        out = torch.cat([source, target, edge_attr], dim=1)
        # 通过边MLP进行处理
        return self.edge_mlp(out)

节点更新模型

class TimeAwareNodeModel(nn.Module):
    """
    用于在神经消息传递期间执行节点更新的类。
    """
    def __init__(self, flow_in_mlp, flow_out_mlp, node_mlp, node_agg_fn):
        super(TimeAwareNodeModel, self).__init__()

        self.flow_in_mlp = flow_in_mlp
        self.flow_out_mlp = flow_out_mlp
        self.node_mlp = node_mlp
        self.node_agg_fn = node_agg_fn

    def forward(self, x, edge_index, edge_attr):
        row, col = edge_index

        # 计算流出流的节点特征更新
        flow_out_mask = row < col
        flow_out_row, flow_out_col = row[flow_out_mask], col[flow_out_mask]
        flow_out_input = torch.cat([x[flow_out_col], edge_attr[flow_out_mask]], dim=1)
        flow_out = self.flow_out_mlp(flow_out_input)
        flow_out = self.node_agg_fn(flow_out, flow_out_row, x.size(0))

        # 计算流入流的节点特征更新
        flow_in_mask = row > col
        flow_in_row, flow_in_col = row[flow_in_mask], col[flow_in_mask]
        flow_in_input = torch.cat([x[flow_in_col], edge_attr[flow_in_mask]], dim=1)
        flow_in = self.flow_in_mlp(flow_in_input)
        flow_in = self.node_agg_fn(flow_in, flow_in_row, x.size(0))

        # 将流入和流出的节点特征更新连接起来
        flow = torch.cat((flow_in, flow_out), dim=1)

        # 通过节点MLP进行处理
        return self.node_mlp(flow)

编码节点和边特征,再进行更新

    def forward(self, data):
        """
        为数据关联问题提供分数解决方案。
        首先,节点和边特征由编码器网络独立编码。然后,通过消息传递网络(self.MPNet)进行固定步骤的迭代“组合”。
        最后,它们由分类器网络独立进行分类。
        Args:
            data: 包含属性的对象
              - x: 节点特征矩阵
              - edge_index: 形状为[2,M]的张量,其中M是边数,表示图邻接矩阵中的非零条目(即边)(即稀疏邻接矩阵)
              - edge_attr: 边特征矩阵(按照在edge_index中出现的边缘排序)
        Returns:
            classified_edges: 经过每个MP步骤后的未归一化节点概率列表
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x_is_img = len(x.shape) == 4
        if self.node_cnn is not None and x_is_img:
            # 图像处理:如果存在节点CNN并且输入是图像
            x = self.node_cnn(x)

            # 计算节点特征之间的欧氏距离,并将其添加到边缘特征中
            emb_dists = nn.functional.pairwise_distance(x[edge_index[0]], x[edge_index[1]]).view(-1, 1)
            edge_attr = torch.cat((edge_attr, emb_dists), dim=1)

        # 编码特征步骤
        latent_edge_feats, latent_node_feats = self.encoder(edge_attr, x)
        initial_edge_feats = latent_edge_feats
        initial_node_feats = latent_node_feats

        # 在训练期间,用于最后几步的MPNetwork输出的特征向量进行分类以计算损失。
        first_class_step = self.num_enc_steps - self.num_class_steps + 1
        outputs_dict = {'classified_edges': []}
        for step in range(1, self.num_enc_steps + 1):

            # 在更新之前重新连接最初编码的嵌入向量
            if self.reattach_initial_edges:
                latent_edge_feats = torch.cat((initial_edge_feats, latent_edge_feats), dim=1)
            if self.reattach_initial_nodes:
                latent_node_feats = torch.cat((initial_node_feats, latent_node_feats), dim=1)

            # 消息传递步骤
            latent_node_feats, latent_edge_feats = self.MPNet(latent_node_feats, edge_index, latent_edge_feats)

            if step >= first_class_step:
                # 分类步骤
                dec_edge_feats, _ = self.classifier(latent_edge_feats)
                outputs_dict['classified_edges'].append(dec_edge_feats)

        if self.num_enc_steps == 0:
            dec_edge_feats, _ = self.classifier(latent_edge_feats)
            outputs_dict['classified_edges'].append(dec_edge_feats)

        return outputs_dict

  • 5
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值