STTran 源码解读4:evaluation_recall.py
目标检测及关系预测模型
目标检测和关系预测模型均不经过梯度下降,不进行训练。
目标检测可视化特征通过预训练模型进行推断,关系预测结果来自tranformer编解码器模型
1) input: test数据和train数据一致 输入向量化的图片信息(img_data)
2) process:
img_data——CNN ——实体特征,联合框特征——实体及联合框表示+语义嵌入表示——
时空Transformer——不同类关系分布(distribution)
3) output: spatial_distribution, contact_distribution
模型评估部分
1) evaluate_scene_graph函数:
收集gt 和 pred:
gt: gt_bbox,gt_class,gt_rel
pred: pred_bbox, pred_class, pred_rel, obj_scores, rel_scores(rel_scores来自模型预测,其余变量均来自gt)
2)evaluate_from_dict函数:
获得实体与实体之间的预测关系
3) evaluate_recall函数:
获得真实三元组和预测三元组
根据关系分数排序三元组
计算recall
output:
pred_to_gt:从谓词中匹配GT
pred_5ples: the predicted(id0,id1,cls0,cls1,rel)
rel_scores:[cls_score1,cls_score2,relscore]
4) compute_pred_mateces函数:
用于计算recall
1.计算三元组是否预测正确
2.计算头尾实体bbox_iou是否大于阈值
3.当三元组和bbox_iou大于阈值的情况下,返回给定预测的GT匹配列表