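```python
# In __init__: a learnable table with one scalar per (hop count, attention head); index 0 is padding
self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)

# In forward: spatial-position (shortest-path distance) bias
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
# Skip row/column 0 (the virtual graph token) and add the bias to node-node pairs only
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
```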
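A minimal, self-contained sketch of the same computation with toy sizes, so the tensor shapes can be checked. The sizes (`n_graph=2`, `n_node=4`, `num_heads=8`, `num_spatial=16`) and the random inputs are illustrative assumptions. Following Graphormer's convention, `graph_attn_bias` is assumed to carry one extra row and column at index 0 for the virtual graph token, which is why the bias is added only to `[:, :, 1:, 1:]`.

```python
import torch
import torch.nn as nn

# Toy sizes (illustrative assumptions, not values from the original model)
n_graph, n_node, num_heads, num_spatial = 2, 4, 8, 16

spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)

# Hop counts per node pair; 0 is reserved for padding, so real values start at 1
spatial_pos = torch.randint(1, num_spatial, (n_graph, n_node, n_node))

# Attention bias with an extra row/column at index 0 for the graph token (assumed layout)
graph_attn_bias = torch.zeros(n_graph, num_heads, n_node + 1, n_node + 1)

emb = spatial_pos_encoder(spatial_pos)      # [n_graph, n_node, n_node, n_head]
spatial_pos_bias = emb.permute(0, 3, 1, 2)  # [n_graph, n_head, n_node, n_node]
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias

print(emb.shape)               # torch.Size([2, 4, 4, 8])
print(spatial_pos_bias.shape)  # torch.Size([2, 8, 4, 4])
print(graph_attn_bias.shape)   # torch.Size([2, 8, 5, 5])
```

Because the embedding width is `num_heads`, each head receives its own scalar bias for a given hop count: `spatial_pos_bias[g, h, i, j]` equals `spatial_pos_encoder.weight[spatial_pos[g, i, j], h]`.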
`spatial_pos` holds the hop count (shortest-path distance) between every pair of nodes; its raw shape is (n_graph, n_node, n_node).
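`nn.Embedding` appends one dimension: each integer hop count acts as an index (equivalent to a one-hot vector) into a learned table of shape (num_spatial, num_heads), so every hop count becomes a vector holding one bias value per attention head. `permute(0, 3, 1, 2)` then moves the head dimension forward to match the layout of `graph_attn_bias`.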
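The one-hot view of the embedding lookup can be checked directly: indexing the table gives the same result as multiplying a one-hot encoding of each hop count by the embedding weight matrix. This is a sketch with illustrative sizes and a hand-written hop matrix (both assumptions); it also shows that, with `padding_idx=0`, index 0 maps to an all-zero vector, so padded node pairs contribute no bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_spatial, num_heads = 16, 8  # illustrative sizes
encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)

# (graph, node, node) integer hop counts; the last node is padding (value 0)
spatial_pos = torch.tensor([[[1, 2, 0],
                             [2, 1, 0],
                             [0, 0, 0]]])  # shape (1, 3, 3)

# Direct lookup: appends a trailing dimension of size num_heads
looked_up = encoder(spatial_pos)  # (1, 3, 3, 8)

# Equivalent one-hot form: (1, 3, 3, num_spatial) @ (num_spatial, num_heads)
one_hot = F.one_hot(spatial_pos, num_classes=num_spatial).float()
via_one_hot = one_hot @ encoder.weight  # (1, 3, 3, 8)

print(torch.allclose(looked_up, via_one_hot))    # True
print(looked_up[0, 0, 2].abs().sum().item())     # 0.0 (padding_idx row is zero)
```

The lookup form is what `nn.Embedding` uses in practice because it avoids materializing the large one-hot tensor; the matrix-product form is shown only to make the "one-hot to vector" reading concrete.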