get_annotations()
# Build the training and validation splits from the dataset class named in
# the config (cfg.DATASET.DATASET), passing the split name and cfg to it.
# NOTE(review): eval() executes the config string as Python code; if the
# config can ever come from an untrusted source, replace this with an
# explicit name -> class registry lookup.
train_set = eval(cfg.DATASET.DATASET)("train", cfg)
valid_set = eval(cfg.DATASET.DATASET)("valid", cfg)
annotations = train_set.get_annotations() # Returns the annotation info of the training split.
def get_annotations(self):
    """Return one annotation dict per sample in ``self.targets``.

    Each entry has the form ``{'category_id': <label>}`` where the label is
    passed through ``int()`` (presumably so non-``int`` label types such as
    numpy scalars become plain Python ints -- confirm against the dataset).
    The output preserves the order of ``self.targets``.
    """
    return [{'category_id': int(label)} for label in self.targets]
假设训练数据集中有以下类别标注:
# Example labels: 10 samples whose class ids are drawn from {0, 1, 2}.
train_set.targets = [0, 1, 0, 2, 1, 2, 2, 0, 1, 2]
那么 get_annotations() 返回的 annotations 列表如下:
[
{'category_id': 0},
{'category_id': 1},
{'category_id': 0},
{'category_id': 2},
{'category_id': 1},
{'category_id': 2},
{'category_id': 2},
{'category_id': 0},
{'category_id': 1},
{'category_id': 2}
]
get_num_classes()（注：下方给出的函数定义是 get_cls_num_list()，它按类别索引依次返回 num_per_cls_dict 中的值；get_num_classes() 本身的实现未在此列出）
# Number of classes, as reported by the training split's get_num_classes().
num_classes = train_set.get_num_classes()
def get_cls_num_list(self):
    """Return ``self.num_per_cls_dict`` values ordered by class index.

    Element ``i`` of the result is ``self.num_per_cls_dict[i]`` for every
    ``i`` in ``range(self.cls_num)`` (presumably the sample count of class
    ``i``). A class index missing from the dict raises ``KeyError``.
    """
    # NOTE(review): the original notes say cls_num = 10 for this dataset;
    # not verifiable from this snippet.
    return [self.num_per_cls_dict[class_idx] for class_idx in range(self.cls_num)]
get_category_list()
def get_category_list(annotations, num_classes, cfg):
    """Count samples per class and collect each sample's class id.

    Args:
        annotations: iterable of dicts, each with a ``"category_id"`` key
            holding an integer class index in ``[0, num_classes)``.
        num_classes: number of classes; sets the length of ``num_list``.
        cfg: unused here; kept so existing call sites keep working.

    Returns:
        ``(num_list, cat_list)``: ``num_list[c]`` is how many annotations
        have category ``c``; ``cat_list`` holds every annotation's category
        id in input order.

    Raises:
        IndexError: if a ``category_id`` lies outside ``[0, num_classes)``.
    """
    num_list = [0] * num_classes
    cat_list = []
    for anno in annotations:
        category_id = anno["category_id"]
        # Reject negatives explicitly: plain list indexing would silently
        # wrap them around and count them toward a class at the end.
        if not 0 <= category_id < num_classes:
            raise IndexError(
                "category_id {!r} out of range [0, {})".format(category_id, num_classes)
            )
        num_list[category_id] += 1
        cat_list.append(category_id)
    # Report completion only once the lists have actually been built.
    print("Weight List has been produced")
    return num_list, cat_list