# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pycocotools.coco as coco
import numpy as np
import torch
import json
import os
import torch.utils.data as data
# Wrap the PASCAL VOC dataset in torch.utils.data.Dataset so it can later be
# loaded with torch.utils.data.DataLoader.
class PascalVOC(data.Dataset):
    """PASCAL VOC detection dataset exposed as a ``torch.utils.data.Dataset``.

    Annotations are COCO-format JSON files loaded with
    ``pycocotools.coco.COCO``; images live under ``<data_dir>/voc/images``.

    NOTE(review): this class defines ``__len__`` but not ``__getitem__``;
    presumably a sampler mixin supplies it elsewhere -- confirm.
    """

    num_classes = 20
    default_resolution = [384, 384]
    # Per-channel normalization constants, reshaped to (1, 1, 3) so they
    # broadcast over the last axis (presumably an HxWx3 image -- confirm).
    mean = np.array([0.485, 0.456, 0.406],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        """Build paths, class metadata and load the annotations for ``split``.

        Args:
            opt: options object; ``opt.data_dir`` is read here and the whole
                object is stored on ``self.opt``.
            split: ``'train'`` (uses ``pascal_trainval0712.json``) or
                ``'val'`` (uses ``pascal_test2007.json``).

        Raises:
            KeyError: if ``split`` is not ``'train'`` or ``'val'``.
        """
        super(PascalVOC, self).__init__()
        # Split name -> annotation file suffix.
        _ann_name = {'train': 'trainval0712', 'val': 'test2007'}
        if split not in _ann_name:
            raise KeyError('unknown split {!r}; expected one of {}'.format(
                split, sorted(_ann_name)))
        self.data_dir = os.path.join(opt.data_dir, 'voc')  # dataset root
        self.img_dir = os.path.join(self.data_dir, 'images')  # image folder
        # Format only the file name; formatting the already-joined path
        # would misbehave if data_dir itself contained '{' or '}'.
        self.annot_path = os.path.join(
            self.data_dir, 'annotations',
            'pascal_{}.json'.format(_ann_name[split]))
        # NOTE(review): presumably the per-image object cap used by the
        # sampler -- confirm against the consumer of this attribute.
        self.max_objs = 50
        self.class_name = ['__background__', "aeroplane", "bicycle", "bird", "boat",
                           "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog",
                           "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
                           "train", "tvmonitor"]
        # Category ids 1..20 (not "validation indices"); index 0 of
        # class_name is the background placeholder.
        self._valid_ids = np.arange(1, 21, dtype=np.int32)
        # Category id -> contiguous 0-based class index.
        self.cat_ids = {v: i for i, v in enumerate(self._valid_ids)}
        # Seeded RNG, presumably consumed by data augmentation elsewhere.
        self._data_rng = np.random.RandomState(123)
        # Eigenvalues / eigenvectors used by the color augmentation applied
        # later in the image-augmentation pipeline.
        self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571],
                                 dtype=np.float32)
        self._eig_vec = np.array([
            [-0.58752847, -0.69563484, 0.41340352],
            [-0.5832747, 0.00994535, -0.81221408],
            [-0.56089297, 0.71832671, 0.41158938]
        ], dtype=np.float32)
        self.split = split
        self.opt = opt

        print('==> initializing pascal {} data.'.format(_ann_name[split]))
        # Load the annotations into memory and build the COCO index.
        self.coco = coco.COCO(self.annot_path)
        # Sorted ids of every image in the annotation file.
        self.images = sorted(self.coco.getImgIds())
        self.num_samples = len(self.images)
        print('Loaded {} {} samples'.format(split, self.num_samples))

    def _to_float(self, x):
        """Return ``x`` rounded to two decimals (via string formatting)."""
        return float("{:.2f}".format(x))

    def convert_eval_format(self, all_bboxes):
        """Regroup detections as ``detections[class_index][image_index]``.

        Args:
            all_bboxes: mapping ``img_id -> {class_j: dets}`` for classes
                ``1..num_classes``; ``dets`` may be a numpy array or an
                already JSON-serializable value.

        Returns:
            A list of ``num_classes + 1`` lists, each with ``num_samples``
            entries in ``self.images`` order. Row 0 (background) stays as
            empty lists; numpy arrays are converted with ``tolist()``.
        """
        detections = [[[] for _ in range(self.num_samples)]
                      for _ in range(self.num_classes + 1)]
        for i in range(self.num_samples):
            per_class = all_bboxes[self.images[i]]
            for j in range(1, self.num_classes + 1):
                dets = per_class[j]
                detections[j][i] = (dets.tolist()
                                    if isinstance(dets, np.ndarray) else dets)
        return detections

    def __len__(self):
        """Return the number of images in this split."""
        return self.num_samples

    def save_results(self, results, save_dir):
        """Write ``convert_eval_format(results)`` to ``<save_dir>/results.json``."""
        # Convert first so a conversion error does not leave a truncated file.
        detections = self.convert_eval_format(results)
        with open(os.path.join(save_dir, 'results.json'), 'w') as f:
            json.dump(detections, f)

    def run_eval(self, results, save_dir):
        """Save ``results`` and run ``tools/reval.py`` on the JSON file.

        ``tools/reval.py`` is resolved relative to the current working
        directory; the script's exit status is not checked.
        """
        import subprocess

        self.save_results(results, save_dir)
        result_json = os.path.join(save_dir, 'results.json')
        # Argument list instead of a shell string: paths with spaces or
        # shell metacharacters are passed through safely.
        subprocess.call(['python', 'tools/reval.py', result_json])
# CenterNet: Objects as Points 代码解析(三): CenterNet/src/lib/datasets/dataset/pascal
# 最新推荐文章于 2021-03-22 15:15:50 发布