Abstract
3D场景图生成主要有两个方面的挑战:
- 与二维图像相比,三维点云只能捕捉到语义有限的几何结构。
- 长尾的关系分布根本上阻碍了无偏的场景图生成。
因此本文的核心思想是**用一个cross-model的模型来帮助3D上的场景图生成。**这种结构可以在模型的训练阶段从2D、语言和3D几何结构中获取到对模型有益的语义信息。将2D和语言的语义信息异质的传输到3D的模型当中。
该模型可以有效的提高现有模型的效果,同时在推理阶段只需要输入3D的数据即可。
实验发现,该模型可以有效的提高以往的场景图生成模型的表现,例如SGFN和SGG_point。
Introduction
背景
所谓3D场景图,就是在3D点云中对物体的标签进行预测,然后生成不同物体之间的关系标签。
3D场景图的下游应用广泛,涉及到VR、AR都可以运用其结果,当然也包括VQA等。
3D场景图生成模型目前的输入是带有instance但没有label的点云数据。
除了传统的场景图生成面对的问题之外,3DSSG还面临以下几个问题:
- 3D数据例如点云只获取每一个instance的几何结构并可能通过相对的方向或距离从表面上定义这些关系。
- 3DSSG谓词数据集相当小,而且存在长尾分布的问题,其中语义谓词往往比几何谓词更稀少。
如上图所示,3D场景图生成的先驱者SGPN倾向于将物体之间的关系预测简单的常见的几何关系,如上图中的sink
和bath cabinet
。但是实际关系应该是<sink,build in,bath cabinet>
,这种关系富含的语义信息更多,是我们更想要的语义关系。
因此,结合场景图生成任务和自然语言任务的相似性,本文采用视觉-语言辅助训练的方法来提高3D模型的效果。因此如何使用这些信息来提高3D模型的表现就成了需要考虑的问题。传统的方法往往使用2D图片的特征来增强对应的3D点云产生的特征,但是这种方法就要求训练阶段和推理阶段都需要2D数据的参与,不符合实际任务。因此本文想到使用CLIP的方法来提高3D模型的表现,该方法本质上运用了CLIP模型的优越性。
本文的模型就是通过结合来自2D、3D和语言三种监督信号来进行异质结合训练。
Method
这篇文章中上标带有oracle的意思是这些特征来自2D图像。
P ∈ R N × 3 P\in R^{N\times 3} P∈RN×3表示点云, M = { M 1 , . . . , M k } M=\{M_1, ..., M_k\} M={M1,...,Mk}表示没有类别标注的instance mask。以上两个为输入,然后需要预测 G = O , R G={O,R} G=O,R,其中 O = { o i } i = 1 K O=\{o_i\}^K_{i=1} O={oi}i=1K表示场景中所有的物体。 o i o_i oi表示主语, o j o_j oj表示宾语, r i j r_{ij} rij表示谓语,三者构成三元组。
3D Prediction Model
Node Encoder
节点编码器还是选用了PointNet进行编码
Oracle Edge Encoder/3D Edge Encoder:
边的初始特征采用了SGFN的边编码方式,对于每一个instance, b = ( b x , b y , b z ) b = (b_x,b_y,b_z) b=(bx,by,bz)表示长宽高, v = b x b y b z v=b_xb_yb_z v=bxbybz表示它们的体积,最长边 l = m a x ( b x , b y , b z ) l=max(b_x, b_y,b_z) l=max(bx,by,bz)。因此边特征可以通过下述表示:
注意这里虽然用的是pointnet,但是并没有使用上stn3k的旋转矩阵,相当于只有一个节点的点云,所以不需要旋转。只进行特征推理。
Scene Graph Reasoning
使用和SGFN一样的GNN结构,使用多头自注意力机制在不同节点和边之间进行特征传递。GNN模型由多头自注意力模块组成,它们会重复 T T T次来获取最后的节点和边特征,然后再用分类器对节点特征和边特征进行分类。
但是注意这里并不需要完全遵守SGFN的模型结构, S G G p o i n t SGG_{point} SGGpoint的模型结构也可以。后续证明了该论文中提出的cross-model方法在不同的主干网络下都有不俗的表现。
Visual-Linguistic Semantics Assisted Training
2D Images frames: 通过相机的内参和外参将物体的点云映射到一组采样帧当中,然后保留该部分2D图像,之后使用CLIP对图像和对应的文本进行特征提取,然后取一组帧的特征平均值作为该物体的2D特征和文本特征。
**Oracle Node Encoder:**这个用的直接就是CLIP封装好的对图像处理的模型。
**Multi-modal Prediction Model as the Oracle:**这个multi-model预测模型对于3D预测模型来讲叫做oracle模型,它和3Dprediction模型采用相同的结构并且同样学习如何预测3D语义场景图,但是它的特征来自2D的视觉特征。它的视觉特征是从RGB图像集合当中获取,这些RGB图像会对应每一个点云(RGB图像的范围是点云投影到该2D平面的范围)
Node-level Collaboration: 该文章在节点层面和边层面进行特征上的融合,融合的时候也采用的是多头自注意力模型。其中keys 和 values 是来自3D模型的节点和边的特征。query则是来自对应2D图像的节点和边的特征。节点层面的添加了一个距离感知的掩码来消除距离很远的instance产生的不必要的attention。其中mask value的计算方法如下:
对于 μ i , μ j \mu_i,\mu_j μi,μj表示点云instance的 P i , P j P_i,P_j Pi,Pj中心。后面那个表示L2正则化距离,也就是先平方再开根号。
边层面的特征交互没有使用距离感知的掩蔽策略,因为edge之间的距离很难定义,因此将所有的边缘纳入注意力的计算是比较安全的。
obj_feature_2d = self.cross_attn[i](obj_feature_2d, obj_feature_3d, obj_feature_3d, attention_weights=obj_distance_weight, way=attention_matrix_way, attention_mask=obj_mask, use_knn=False) # 节点层面使用了obj_distance_weight
edge_feature_2d = self.cross_attn_rel[i](edge_feature_2d, edge_feature_3d, edge_feature_3d, use_knn=False) # 边层面没有使用权重
Auxiliary Training Strategies
利用CLIP的方式来获取视觉语言的知识。具体来讲,该模型通过对场景图中的一个三元组进行文本上的prompt,构造A scene of a/an [subject][predicate] a/an [object]
,然后利用CLIP的text embedding生成
e
i
j
t
e
x
t
e^{text}_{ij}
eijtext。之后对于GNN层中每一个三元组
{
o
i
o
r
a
c
l
e
,
r
i
j
o
r
a
c
l
e
,
o
j
o
r
a
c
l
e
}
\{o_i^{oracle}, r_{ij}^{oracle}, o_j^{oracle}\}
{oioracle,rijoracle,ojoracle}最小化的文本嵌入和融合triplet特征之间的距离。
其中 ρ ( ⋅ , ⋅ ) \rho(\cdot,\cdot) ρ(⋅,⋅)表示距离函数,可以使用 l 1 l_1 l1正则化函数negative cosine distance。 ∣ ∣ || ∣∣表示指示函数,当它等于1时表示argument是真实的。因此,上述损失函数只考虑有文本描述的三元组。
同时,在3D和2D之间也会构造一个损失函数来计算之间的距离,让3D的特征靠近2D的特征
为了提高模型的表现,2D处的instance encoder是一个固定的经过CLIP训练的模型。此外,为了提高object classifiers的效果,使用CLIP的参数对3D和2D的分类器都进行了初始化。
最后,总的损失函数如下
Experiment
实验一
A@k的含义是topk的准确率,也就是预测正确的数量占总数量的大小。这里的总数量指的是该类别下的关系或者物体的数量。
以预测物体的A@k为例,对于一个物体A,物体类别字典中一共有160种物体,然后预测该物体A的类别。最后发现在物体A的真实类别上预测概率为0.9,且在所有物体类别上排名为1/160,则此时它可以被加入到A@1集合当中,如果排名大于1,则不可以加入到A@1集合当中。最后把加入到A@1集合当中的物体数量除以整个验证集的物体数量,得到A@1的值。
下表是代码运行的结果,这里有疑问:
Model | A@1 | A@5 | A@10 | A@1 | A@3 | A@5 | mA@1 | mA@3 | mA@5 | A@50 | A@100 | mA@50 | mA@100 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
VL-SAT | 56.10 | 77.80 | 85.90 | 90.16 | 98.47 | 99.46 | 50.55 | 74.27 | 88.67 | 90.29 | 92.77 | 64.73 | 74.12 |
VL-SAT(oracle) | 66.87 | 86.47 | 91.12 | 90.53 | 98.37 | 99.38 | 55.59 | 76.67 | 87.16 | 93.04 | 95.09 | 72.49 | 80.19 |
统计结果:
Relation predict result:
Relation: supported by top1: 26(12.04) top3: 125(57.87) top5: 181(83.80) Length: 216
Relation: left top1: 1572(91.77) top3: 1683(98.25) top5: 1710(99.82) Length: 1713
Relation: right top1: 1595(93.11) top3: 1685(98.37) top5: 1708(99.71) Length: 1713
Relation: front top1: 774(75.59) top3: 1007(98.34) top5: 1022(99.80) Length: 1024
Relation: behind top1: 699(68.26) top3: 991(96.78) top5: 1021(99.71) Length: 1024
Relation: close by top1: 1480(83.81) top3: 1688(95.58) top5: 1750(99.09) Length: 1766
Relation: bigger than top1: 26(30.59) top3: 65(76.47) top5: 81(95.29) Length: 85
Relation: smaller than top1: 31(36.47) top3: 61(71.76) top5: 83(97.65) Length: 85
Relation: higher than top1: 137(70.26) top3: 177(90.77) top5: 190(97.44) Length: 195
Relation: lower than top1: 136(69.74) top3: 171(87.69) top5: 192(98.46) Length: 195
Relation: same symmetry as top1: 3(6.00) top3: 22(44.00) top5: 34(68.00) Length: 50
Relation: same as top1: 92(42.99) top3: 181(84.58) top5: 206(96.26) Length: 214
Relation: attached to top1: 888(87.83) top3: 995(98.42) top5: 1003(99.21) Length: 1011
Relation: standing on top1: 1170(86.22) top3: 1330(98.01) top5: 1345(99.12) Length: 1357
Relation: lying on top1: 163(70.26) top3: 213(91.81) top5: 222(95.69) Length: 232
Relation: hanging on top1: 109(68.12) top3: 156(97.50) top5: 159(99.38) Length: 160
Relation: connected to top1: 21(61.76) top3: 23(67.65) top5: 27(79.41) Length: 34
Relation: leaning against top1: 0(0.00) top3: 1(5.26) top5: 4(21.05) Length: 19
Relation: part of top1: 5(83.33) top3: 5(83.33) top5: 6(100.00) Length: 6
Relation: belonging to top1: 14(45.16) top3: 17(54.84) top5: 19(61.29) Length: 31
Relation: build in top1: 27(81.82) top3: 30(90.91) top5: 30(90.91) Length: 33
Relation: standing in top1: 7(28.00) top3: 18(72.00) top5: 21(84.00) Length: 25
Relation: cover top1: 0(0.00) top3: 3(16.67) top5: 14(77.78) Length: 18
Relation: lying in top1: 0(0.00) top3: 2(5.56) top5: 7(19.44) Length: 36
Relation: hanging in top1: 0(0.00) top3: 0(0.00) top5: 0(0.00) Length: 1
Relation: None top1: 27753(95.37) top3: 29069(99.90) top5: 29099(100.00) Length: 29099
Only position
Eval: 3d obj Acc@1 : 46.0695468914647
Eval: 2d obj Acc@1: 63.05584826132771
Eval: 3d obj Acc@5 : 70.43203371970495
Eval: 2d obj Acc@5: 84.0463645943098
Eval: 3d obj Acc@10 : 80.84299262381454
Eval: 2d obj Acc@10: 89.96838777660696
Eval: 3d rel Acc@1 : 90.1070844281394
Eval: 3d mean rel Acc@1 : 40.81087348741942
Eval: 2d rel Acc@1: 90.94987853849587
Eval: 2d mean rel Acc@1: 47.82408087830366
Eval: 3d rel Acc@3 : 98.08636160824946
Eval: 3d mean rel Acc@3 : 62.94893506113851
Eval: 2d rel Acc@3: 98.0566159337663
Eval: 2d mean rel Acc@3: 71.18989247570352
Eval: 3d rel Acc@5 : 99.25387933171385
Eval: 3d mean rel Acc@5 : 75.25894753060192
Eval: 2d rel Acc@5: 99.25140052550691
Eval: 2d mean rel Acc@5: 82.9512407526059
Eval: 3d triplet Acc@50 : 87.03832234395915
Eval: 2d triplet Acc@50: 92.08765058747707
Eval: 3d triplet Acc@100 : 90.10460562193248
Eval: 2d triplet Acc@100: 94.43508006544049
Eval: 3d mean recall@50 : 45.03740310668945
Eval: 2d mean recall@50: 68.3842544555664
Eval: 3d mean recall@100 : 57.1867790222168
Eval: 2d mean recall@100: 77.78561401367188
Eval: 3d zero-shot recall@50 : 15.979175010012014
Eval: 3d zero-shot recall@100: 30.436523828594314
Eval: 3d non-zero-shot recall@50 : 64.20077749828494
Eval: 3d non-zero-shot recall@100: 74.21678481591584
Eval: 3d all-zero-shot recall@50 : 53.49106110468737
Eval: 3d all-zero-shot recall@100: 64.49346259895046
Eval: 3d recall@20 : 0.20853536662134267
Eval: 3d recall@50 : 0.27538180060311623
Eval: 3d recall@100 : 0.32623018792059894
Zeros
Eval: 3d obj Acc@1 : 50.663856691253955
Eval: 2d obj Acc@1: 66.36459430979978
Eval: 3d obj Acc@5 : 74.58377239199157
Eval: 2d obj Acc@5: 86.28029504741833
Eval: 3d obj Acc@10 : 84.0463645943098
Eval: 2d obj Acc@10: 91.46469968387777
Eval: 3d rel Acc@1 : 91.07629765504933
Eval: 3d mean rel Acc@1 : 43.015418442443135
Eval: 2d rel Acc@1: 91.84720638540479
Eval: 2d mean rel Acc@1: 53.65014863265658
Eval: 3d rel Acc@3 : 98.19047146894056
Eval: 3d mean rel Acc@3 : 66.89854773255382
Eval: 2d rel Acc@3: 98.23508998066531
Eval: 2d mean rel Acc@3: 73.95822022892116
Eval: 3d rel Acc@5 : 99.28858261861087
Eval: 3d mean rel Acc@5 : 77.42369241436364
Eval: 2d rel Acc@5: 99.32576471171484
Eval: 2d mean rel Acc@5: 86.41257977931141
Eval: 3d triplet Acc@50 : 88.26037380397601
Eval: 2d triplet Acc@50: 93.3964602647365
Eval: 3d triplet Acc@100 : 91.2027167716028
Eval: 2d triplet Acc@100: 95.43899657924743
Eval: 3d mean recall@50 : 49.7222785949707
Eval: 2d mean recall@50: 71.30230712890625
Eval: 3d mean recall@100 : 61.00960922241211
Eval: 2d mean recall@100: 79.39472961425781
Eval: 3d zero-shot recall@50 : 19.903884661593914
Eval: 3d zero-shot recall@100: 34.281137364837804
Eval: 3d non-zero-shot recall@50 : 68.71712782986508
Eval: 3d non-zero-shot recall@100: 78.1843128287217
Eval: 3d all-zero-shot recall@50 : 57.876011740638624
Eval: 3d all-zero-shot recall@100: 68.43369207506893
Eval: 3d recall@20 : 0.23439225219097373
Eval: 3d recall@50 : 0.2968114841186931
Eval: 3d recall@100 : 0.34787297334654416
binary cross entropy loss with true object label
Triplet的mA是怎么计算的?因为首先对于mA的定义,对于Object来讲,应该是在每一个物体类别上准确率然后平均。那这里Triplet的话就应该在每一个Triplet的类别上计算准确率然后求平均。
但是根据代码来看,代码中用来计算得到Triplet的mA@k的函数名称叫做get_mean_recall
,虽然最后根据代码内容来看计算的实际是准确率,但是它所用的Triplet类别实际上是取的场景中object或者predicate的类别(根据谁的索引更小),也就是说这里Triplet mA@100实际上的平均是在每一个predicate类别或者object类别上进行平均(每一个场景都不同)。
这一处计算方法个人感觉非常奇怪,为什么在取
cls_num
的时候要用max进行取。
实验二
- scene graph classification (SGCls) :同时预测object和predicate的语义类别
- predicate classification (PredCls):object类别给定,预测predicate类别
计算这SGCls和PredCls没有结合到模型的框架当中,但给出了evaluate_triplet_recallk函数,因此我尝试将其结合进模型当中并运行结果。
从结果来看,符合但目前发现代码中一个部分写的可能有问题,后续还需要结合3DSSG代码确定是否有误。
实验三
- CI means CLIP-initialized object classifier.
- NC means node-level collaboration
- EC means edge-level collaboration
- TR means triplet-level CLIP-based regularization
代码部分
数据预处理
数据预处理部分主要工作是将点云中有标注的物体根据相机的内参和外参映射到2D图像当中,并对映射产生的2D图像以物体为中心进行裁剪然后保存。
# 读取relation_train.json或者relation_val.json的内容
for i in selected_scans:
def read_pointcloud(scan_id):
"""读取点云数据"""
plydata = trimesh.load('labels.instances.annotated.v2.ply'), process=False)
points = np.array(plydata.vertices) # 点构成的列表
labels = np.array(plydata.metadata['_ply_raw']['vertex']['data']['objectId']) # 每一个点的实例标签,对应object的instance id
def read_scan_info():
"""
读取场景的信息,包括内参矩阵,外参数矩阵,rgb图像list
注意外参矩阵是每个图像对应一个,而内参矩阵只有一个,因为只有一个相机
"""
return image_list, extrinsic_list, intrinsic_info
def map_pc_to_image():
instance_id = set(instance_names.keys()) # 场景中的物体的instance id
# get clip match rate to filter some transport noise
image_feature = model.encode_image(image_input) # clip编码image图像列表
image_feature /= image_feature.norm(dim=-1, keepdim=True) # [num_img, 512]
similarity = (image_feature @ class_weight.T).softmax(dim=-1) # [num_img, 160]
for i in instance_id:
# 将instance_id对应的点通过相机内参和外参映射到RGB图像当中,然后确定在图像范围内的点的索引 index
indexs = ((c_2_i[...,0]< width) & (c_2_i[...,0]>0) & (c_2_i[...,1]< height) & (c_2_i[...,1]>0))
# 根据物体类别筛选经过CLIP编码的图像特征和类别文本特征相似度较高的image,过滤相似度较低的image
class_idx = class_list.index(instance_names[i])
topk_index = (-similarity[:, class_idx]).argsort()[:topk]
for k in topk_index:
# 先取出点云映射后的像素点坐标,再将物体所占区域经过扩大后裁剪出来,并保存
# 最后将裁剪的图像和原本的图像都用CLIP进行特征提取,然后存储为list。
# 保存被裁减的图像和原图像的均值特征。
读取训练数据
class Dataset
# 首先根据scan id读取points和每个点对应的instance id
data = load_mesh(path, self.mconfig.label_file, self.use_rgb, self.use_normal) # 读取points和instances
points = data['points']
instances = data['instances']
for i, instance_id in enumerate(nodes): # 遍历场景中的物体的instance id
# 获取物体的点集
obj_pointset = points[np.where(instances == instance_id)[0]]
choice = np.random.choice(len(obj_pointset), num_points, replace=True) # 随机选择一组点
obj_pointset = obj_pointset[choice, :]
obj_pointset = torch.from_numpy(obj_pointset.astype(np.float32))
# 用于下一个for循环确定edge的联合bbox
min_box = np.min(obj_pointset[:,:3], 0) - padding
max_box = np.max(obj_pointset[:,:3], 0) + padding
instances_box[instance_id] = (min_box,max_box) # 获取物体对应的bbox的左上角和右下角坐标
obj_pointset[:,:3] = self.zero_mean(obj_pointset[:,:3]) # 中心化,正则化
obj_points[i] = obj_pointset
descriptor[i] = op_utils.gen_descriptor(torch.from_numpy(obj_pointset)[:,:3]) # 生成obj点云的标准差,中心,最长边等函数存在列表里
for e in range(len(edge_indices)): # 遍历边索引
edge = edge_indices[e]
gt_rels[e,:] = adj_matrix_onehot[index1,index2,:] # 获取边索引和边one-hot编码之间的索引表
instance1 = nodes[edge[0]]
instance2 = nodes[edge[1]]
mask1 = (instances == instance1).astype(np.int32) * 1
mask2 = (instances == instance2).astype(np.int32) * 2
mask_ = np.expand_dims(mask1 + mask2, 1) # torch.unsqueeze()
bbox1 = instances_box[instance1] # 获取物体对应的bbox,上一个for循环中确定
bbox2 = instances_box[instance2]
min_box = np.minimum(bbox1[0], bbox2[0])
max_box = np.maximum(bbox1[1], bbox2[1])
# 确定整个点云中哪些点属于二者的联合bbox
filter_mask = (points[:,0] > min_box[0]) * (points[:,0] < max_box[0]) \
* (points[:,1] > min_box[1]) * (points[:,1] < max_box[1]) \
* (points[:,2] > min_box[2]) * (points[:,2] < max_box[2])
# add with context, to distingush the different object's points
points4d = np.concatenate([points, mask_], 1)
pointset = points4d[np.where(filter_mask > 0)[0], :] # 过滤出两个bbox中的点
# 最后再和处理obj点云中一样进行随机取点就可以了
return obj_points, obj_2d_feats, rel_points, gt_rels, label_node, edge_indices, descriptor
训练函数
obj_feature = self.obj_encoder(obj_points) # pointnet [B * num_obj_per_scan, 768]
obj_feature = self.mlp_3d(obj_feature) # MLP
if self.mconfig.USE_SPATIAL:
tmp = descriptor[:,3:].clone() # descriptor:中心点坐标,方差,最大点和最小点的坐标差值,体积,最长边
tmp[:,6:] = tmp[:,6:].log() # only log on volume and length
obj_feature = torch.cat([obj_feature, tmp],dim=-1) # 特征增强
edge_feature = op_utils.Gen_edge_descriptor(flow=self.flow)(descriptor, edge_indices) # 利用MessagePassing来实现edge初始特征的构建,edge特征主要通过descriptor的特征获得
rel_feature_2d = self.rel_encoder_2d(edge_feature) # pointnet
rel_feature_3d = self.rel_encoder_3d(edge_feature) # pointnet
obj_2d_feats = self.clip_adapter(obj_2d_feats) # clip模型处理生成的图像特征再经过一个MLP来方便处理
# 核心模型
gcn_obj_feature_3d, gcn_obj_feature_2d, gcn_edge_feature_3d, gcn_edge_feature_2d = self.mmg(obj_feature, obj_2d_feats, rel_feature_3d, rel_feature_2d,edge_indices, batch_ids, obj_center, descriptor.clone(), istrain=istrain)
# 生成边特征,通过将物体特征和边特征进行拼接再过MLP构造。
gcn_edge_feature_2d_dis = self.generate_object_pair_features(gcn_obj_feature_2d, gcn_edge_feature_2d, edge_indices) # [edge_num, 512 + 512 + 512]
# 最后用Pointnet的分类函数对3D和2D的obj和rel都进行分类
# 计算损失
# 首先是物体的分类损失
loss_obj_3d = F.cross_entropy(obj_logits_3d, gt_cls) # 交叉熵计算分类损失
loss_obj_2d = F.cross_entropy(obj_logits_2d, gt_cls)
# 计算边的分类权重
batch_mean = torch.sum(gt_rel_cls, dim=(0))
zeros = (gt_rel_cls.sum(-1) ==0).sum().unsqueeze(0) # 统计关系类别为None的关系数量
batch_mean = torch.cat([zeros,batch_mean],dim=0) # 记录每种关系类别的数量
weight = torch.abs(1.0 / (torch.log(batch_mean+1)+1)) # +1 to prevent 1 /log(1) = inf
if ignore_none_rel:
weight[0] = 0
weight *= 1e-2 # reduce the weight from ScanNet
weight[torch.where(weight==0)] = weight[0].clone() if not ignore_none_rel else 0# * 1e-3
weight = weight[1:] # 不取None relaiton
# 边的交叉熵损失
loss_rel_3d = F.binary_cross_entropy(rel_cls_3d, gt_rel_cls, weight=weight)
loss_rel_2d = F.binary_cross_entropy(rel_cls_2d, gt_rel_cls, weight=weight)
# 在3D和2D物体之间构造相似度损失,让3D的物体特征靠近2D
loss_mimic = self.cosine_loss(obj_feature_3d, obj_feature_2d, t=0.8)
# 让2D的边特征靠近CLIP产生的文本特征
rel_mimic_2d = F.l1_loss(edge_feature_2d, rel_text_feat) # 模型预测值f(x)和真实值y之间距离的平均值
loss = lambda_o * (loss_obj_2d + loss_obj_3d) + 3 * lambda_r * (loss_rel_2d + loss_rel_3d) + 0.1 * (loss_mimic + rel_mimic_2d)
self.mmg()
该处对应模型的GCN消息传递部分
# 第一个for循环为了获取一个场景下物体两两之间的距离权重,用于后续的特征传递
for i in range(batch_size):
center_A = obj_center[None, idx_i, :].clone().detach().repeat(len(idx_i), 1, 1) # 获取一个场景的中心
center_B = obj_center[idx_i, None, :].clone().detach().repeat(1, len(idx_i), 1)
center_dist = (center_A - center_B)
dist = center_dist.pow(2) # 取平方
dist = torch.sqrt(torch.sum(dist, dim=-1))[:, :, None]
weights = torch.cat([center_dist, dist], dim=-1).unsqueeze(0) # 1 N N 4
dist_weights = self.self_attn_fc(weights).permute(0,3,1,2) # 1 num_heads N N
attention_matrix_way = 'add'
obj_distance_weight[:, :, count:count + len(idx_i), count:count + len(idx_i)] = dist_weights # 存储dist attention权重,后续和obj_feature_3d产生的自注意力一起使用,当做距离权重
for i in range(self.depth):
obj_feature_3d = self.self_attn[i](obj_feature_3d, obj_feature_3d, obj_feature_3d, attention_weights=obj_distance_weight, way=attention_matrix_way, attention_mask=obj_mask, use_knn=False) # 自注意力
obj_feature_2d = self.cross_attn[i](obj_feature_2d, obj_feature_3d, obj_feature_3d, attention_weights=obj_distance_weight, way=attention_matrix_way, attention_mask=obj_mask, use_knn=False) # 交叉注意力
# 然后全部喂入图注意力卷积网络当中,至此obj的特征处理完了
obj_feature_3d, edge_feature_3d = self.gcn_3ds[i](obj_feature_3d, edge_feature_3d, edge_index, istrain=istrain)
obj_feature_2d, edge_feature_2d = self.gcn_2ds[i](obj_feature_2d, edge_feature_2d, edge_index, istrain=istrain)
# 之后处理边的特征,以3D边特征为key和value,以2D边特征作为query。利用cross-attention
edge_feature_2d = self.cross_attn_rel[i](edge_feature_2d, edge_feature_3d, edge_feature_3d, use_knn=False)