BBN 论文代码分析

 get_annotations()

    train_set = eval(cfg.DATASET.DATASET)("train", cfg)
    valid_set = eval(cfg.DATASET.DATASET)("valid", cfg)

    annotations = train_set.get_annotations() # 用于返回训练数据集的标注信息。
    def get_annotations(self):
        annos = []
        for target in self.targets:
            annos.append({'category_id': int(target)})
        return annos

假设训练数据集中有以下类别标注:

train_set.targets = [0, 1, 0, 2, 1, 2, 2, 0, 1, 2]

那么 get_annotations() 返回的 annotations 列表如下:

[
    {'category_id': 0},
    {'category_id': 1},
    {'category_id': 0},
    {'category_id': 2},
    {'category_id': 1},
    {'category_id': 2},
    {'category_id': 2},
    {'category_id': 0},
    {'category_id': 1},
    {'category_id': 2}
]

get_num_classes()

num_classes = train_set.get_num_classes()
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num): # cls_num = 10
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

get_category_list()

def get_category_list(annotations, num_classes, cfg):
    num_list = [0] * num_classes
    cat_list = []
    print("Weight List has been produced")
    for anno in annotations:
        category_id = anno["category_id"]
        num_list[category_id] += 1
        cat_list.append(category_id)
    return num_list, cat_list # 返回包含每个类别样本数量的 num_list 和所有样本类别标识符的 cat_list

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值