scene graph generation 计算mean recall数据的过程:

前言:

计算流程这里参考maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py这个scene graph generation benchmark的github官网来完成相关的任务。

计算mean recall的详细过程

以下是如何利用预测三元组和groundtruth三元组计算mean recall的详细过程:

1. 准备数据

使用如下两个变量来分别保存groundtruth三元组和predicate的三元组

  • prepare_gt方法会处理groundtruth三元组数据,并将其存储在一个字典中。
  • prepare_pred方法会处理预测三元组数据,并将其存储在一个字典中。

2. 计算每个类别的recall

  • calculate_recall方法会遍历所有的groundtruth和predicate数据计算每个关系类别的recall

  • 对于每个类别,计算公式为:
    R e c a l l = T P T P + F N Recall = \frac{TP}{TP+FN} Recall=TP+FNTP

    其中,TP是True Positives,FN是False Negatives。(这句的意思看这句代码就理解了,即:float(len(match)) / float(gt_rels.shape[0])也就是说,(正确匹配的三元组)/所有groundtruth三元组

  1. 计算mean recall
    • calculate_mean_recall方法会计算所有类别的平均recall。
    • 首先,计算每个类别的recall
    • 然后,计算所有类别recall的平均值

M e a n R e c a l l = ∑ R e c a l l i N Mean Recall = \frac{\sum{Recall_i}}{N} MeanRecall=NRecalli

其中, R e c a l l i Recall_i Recalli是第 i i i个关系类别的 r e c a l l recall recall N N N是类别的总数

具体代码片段

以下是一些关键的代码片段的解释:(这些片段是从github文件中专门拿出来的)

实际上只读最下面的完整代码即可有不懂的再看这些代码片段

准备groundtruth数据

def prepare_gt(self):
    for gt in self.gts:
        gt_entry = {}
        gt_entry['relations'] = gt['relations']
        gt_entry['boxes'] = gt['boxes']
        gt_entry['labels'] = gt['labels']
        self.gt_entries.append(gt_entry)

准备预测数据

def prepare_pred(self):
    for pred in self.preds:
        pred_entry = {}
        pred_entry['relations'] = pred['relations']
        pred_entry['boxes'] = pred['boxes']
        pred_entry['labels'] = pred['labels']
        self.pred_entries.append(pred_entry)

计算recall

注意这里的TP和FN,就是上面公式了的TP和FN

def calculate_recall(self):
    for i, gt_entry in enumerate(self.gt_entries):
        pred_entry = self.pred_entries[i]
        for rel in gt_entry['relations']:
            gt_rel = (rel[0], rel[1], rel[2])
            if gt_rel in pred_entry['relations']:
                self.tp[rel[2]] += 1
            else:
                self.fn[rel[2]] += 1

计算mean recall

计算每一个关系的mean recall的流程,就是:
float(len(match)) / float(gt_rels.shape[0] 也就是预测中的gt三元组数/总的gt三元组数

计算mean recall呢,就是把每个关系计算的mean recall ,最后除以关系数量。 (VG150是50个关系数量)

def calculate_mean_recall(self):
    recalls = []
    for i in range(self.num_rel_classes):
        if self.tp[i] + self.fn[i] > 0:
            recalls.append(self.tp[i] / (self.tp[i] + self.fn[i]))
    mean_recall = sum(recalls) / len(recalls)
    return mean_recall

完整代码

只读下面完整代码即可,简洁明快

# 从groundtruth和predicate数据中,按照image_id取出image_id对应的三元组
def prepare_data(gt_data, pred_data):    
    gt_dict = {}
    pred_dict = {}

    for gt in gt_data:
        gt_dict[gt['image_id']] = gt['gt_triplets']

    for pred in pred_data:
        pred_dict[pred['image_id']] = [triplet['triplets'] for triplet in pred['pred_triplets']]

    return gt_dict, pred_dict

# 这个函数, gt_dict和pred_dict是prepare_data函数的返回数据,relationships是50个关系种类,top_k是指计算20recall,50recall,100recall的过程。
# Recall的计算公式: Recall = TP / (TP+FN)
# 其中TP是True Positives就是Groundtruth三元组里,被predicate数据预测到的那部分
# FN是False Negatives,就是Groundtruth三元组里。
# 这个公式的意思看这句代码就理解了,即:`float(len(match)) / float(gt_rels.shape[0]`)也就是说,(正确匹配的三元组)/所有groundtruth三元组)
def calculate_recall(gt_dict, pred_dict, relationships, top_k):    
    tp = defaultdict(int)    # Groundtruth三元组里,被predicate数据预测到的那部分
    fn = defaultdict(int)    # Groundtruth三元组里,没有predicate数据预测到的那部分   TP+FN 加起来是 GT三元组的总数

    for image_id in gt_dict:    # 对gt_dict中的每一个image_id
        if image_id not in pred_dict:
            continue

        gt_triplets = gt_dict[image_id]   # 取出当前id图片的gt_riplets
        pred_triplets = pred_dict[image_id][:top_k]    # 取出当前id图片的pred_riplets

        for gt_triplet in gt_triplets:
            subject, predicate, obj = gt_triplet    # 分析三元组(这里我的三元组格式是<subject,predicate,object>)
            if predicate in relationships:    # 如果当前的关系在50个关系类别里,我们再按recall公式计算
                if gt_triplet in pred_triplets:
                    tp[predicate] += 1    # 看recall公式,当前谓词的tp+1
                else:
                    fn[predicate] += 1    # 看recall公式,当前谓词的fn+1
                    
                   

    # 上面那些代码已经把所有的image_id遍历完了,现在开始按关系来计算关系的recall值
    recall = {}
    for predicate in relationships:
        if tp[predicate] + fn[predicate] > 0:   # 否则这个predicate就是不在50个关系类别里的谓词,加起来等于0,不可能大于0
            recall[predicate] = tp[predicate] / (tp[predicate] + fn[predicate])    # recall公式
        else:
            recall[predicate] = 0.0

    return recall

# 每个关系的recall值都得到了以后,计算所有关系的mean recall值
# 实际上
def calculate_mean_recall(recall):
    recall_values = list(recall.values())
    mean_recall = sum(recall_values) / len(recall_values)
    return mean_recall

若要调用以上的代码函数的话,可以这样调用

gt_file_path = 'gt_data.json'
pred_file_path = 'pred_data.json'

gt_data = load_gt_data(gt_file_path)
pred_data = load_pred_data(pred_file_path)

#数据获取完以后,比如gt_data,是image_id对应一系列gt的triplets,比如pred_dict,是image_id对应一系列pred的triplets
gt_dict, pred_dict = prepare_data(gt_data, pred_data)

relationships = ['above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind', 'belonging to',
                 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for', 'from', 'growing on',
                 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on', 'looking at', 'lying on',
                 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over', 'painted on', 'parked on',
                 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on', 'to', 'under', 'using',
                 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']

# 计算每个关系类别的recall值
recall_20 = calculate_recall(gt_dict, pred_dict, relationships, top_k=20)
recall_50 = calculate_recall(gt_dict, pred_dict, relationships, top_k=50)
recall_100 = calculate_recall(gt_dict, pred_dict, relationships, top_k=100)

# 计算总的mean recall值
mean_recall_20 = calculate_mean_recall(recall_20)
mean_recall_50 = calculate_mean_recall(recall_50)
mean_recall_100 = calculate_mean_recall(recall_100)

print(f"Mean Recall @20: {mean_recall_20}")
print(f"Mean Recall @50: {mean_recall_50}")
print(f"Mean Recall @100: {mean_recall_100}")

希望对大家有帮助 。有不会的及时留言。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值