Using the pycocotools library

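This article gives a detailed introduction to the pycocotools package, covering the coco, mask, and cocoeval modules. The coco module provides methods for getting annotation ids, category ids, and image ids, and for loading and displaying annotations. The mask module contains operations such as RLE encoding and decoding. The cocoeval module is used to evaluate algorithm performance. The article also shows how to load the COCO dataset with a COCODetection class and preprocess it.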

1. Modules

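pycocotools contains four modules: coco, cocoeval, mask, and _mask. The first three are the public Python modules; _mask is the compiled extension that mask wraps.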

(1) The coco module:

# The following API functions are defined:
#  COCO       - COCO api class that loads COCO annotation file and prepare data structures.
#  getAnnIds  - Get ann ids that satisfy given filter conditions.
#  getCatIds  - Get cat ids that satisfy given filter conditions.
#  getImgIds  - Get img ids that satisfy given filter conditions.
#  loadAnns   - Load anns with the specified ids.
#  loadCats   - Load cats with the specified ids.
#  loadImgs   - Load imgs with the specified ids.
#  annToMask  - Convert segmentation in an annotation to binary mask.
#  showAnns   - Display the specified annotations.
#  loadRes    - Load algorithm results and create API for accessing them.
#  download   - Download COCO images from mscoco.org server.
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
# Help on each functions can be accessed by: "help COCO>function".

The COCO class defines the following 11 methods:

(1) Get annotation ids:

def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
        """
        Get ann ids that satisfy given filter conditions. default skips that filter
        :param imgIds  (int array)     : get anns for given imgs
               catIds  (int array)     : get anns for given cats
               areaRng (float array)   : get anns for given area range (e.g. [0 inf])
               iscrowd (boolean)       : get anns for given crowd label (False or True)
        :return: ids (int array)       : integer array of ann ids
        """
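
The filter semantics described in this docstring can be illustrated with a small pure-Python sketch. This is not the library's code: `get_ann_ids` and `toy_anns` are hypothetical names, the annotations are made up, and treating the area range as an open interval is an assumption of the sketch. An empty filter is skipped, and the remaining filters are combined with AND.

```python
def get_ann_ids(anns, img_ids=(), cat_ids=(), area_rng=(), iscrowd=None):
    """Return the ids of the annotations that pass every non-empty filter."""
    selected = []
    for ann in anns:
        if img_ids and ann["image_id"] not in img_ids:
            continue
        if cat_ids and ann["category_id"] not in cat_ids:
            continue
        # Assumption of this sketch: the area range is an open interval (lo, hi).
        if area_rng and not (area_rng[0] < ann["area"] < area_rng[1]):
            continue
        if iscrowd is not None and ann["iscrowd"] != iscrowd:
            continue
        selected.append(ann["id"])
    return selected


# Hypothetical toy annotations, not taken from any real dataset.
toy_anns = [
    {"id": 1, "image_id": 10, "category_id": 1, "area": 50.0, "iscrowd": 0},
    {"id": 2, "image_id": 10, "category_id": 2, "area": 900.0, "iscrowd": 0},
    {"id": 3, "image_id": 11, "category_id": 1, "area": 400.0, "iscrowd": 1},
]

print(get_ann_ids(toy_anns))                                # [1, 2, 3]
print(get_ann_ids(toy_anns, img_ids=[10]))                  # [1, 2]
print(get_ann_ids(toy_anns, cat_ids=[1], iscrowd=False))    # [1]
print(get_ann_ids(toy_anns, area_rng=[100, float("inf")]))  # [2, 3]
```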

(2) Get category ids:

def getCatIds(self, catNms=[], supNms=[], catIds=[]):
        """
        filtering parameters. default skips that filter.
        :param catNms (str array)  : get cats for given cat names
        :param supNms (str array)  : get cats for given supercategory names
        :param catIds (int array)  : get cats for given cat ids
        :return: ids (int array)   : integer array of cat ids
        """

(3) Get image ids:

def getImgIds(self, imgIds=[], catIds=[]):
        '''
        Get img ids that satisfy given filter conditions.
        :param imgIds (int array) : get imgs for given ids
        :param catIds (int array) : get imgs with all given cats
        :return: ids (int array)  : integer array of img ids
        '''
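
The phrase "get imgs with all given cats" means that several category ids are combined by intersection: an image is returned only if it contains every listed category. A rough sketch of that behaviour, using a hypothetical `get_img_ids` helper and a made-up category-to-images mapping (the result is sorted here only to make the output deterministic):

```python
def get_img_ids(cat_to_imgs, all_img_ids, img_ids=(), cat_ids=()):
    """Return image ids that are in img_ids (if given) and contain ALL cat_ids."""
    if not img_ids and not cat_ids:
        return sorted(all_img_ids)
    ids = set(img_ids)
    for i, cat_id in enumerate(cat_ids):
        images_with_cat = set(cat_to_imgs.get(cat_id, []))
        if i == 0 and not ids:
            # No image filter was given: start from the first category
            ids = images_with_cat
        else:
            # Every further filter narrows the set by intersection
            ids &= images_with_cat
    return sorted(ids)


# Hypothetical mapping: category id -> image ids that contain it
toy_cat_to_imgs = {1: [10, 11, 11], 2: [10, 12]}
toy_img_ids = [10, 11, 12]

print(get_img_ids(toy_cat_to_imgs, toy_img_ids))                    # [10, 11, 12]
print(get_img_ids(toy_cat_to_imgs, toy_img_ids, cat_ids=[1]))       # [10, 11]
print(get_img_ids(toy_cat_to_imgs, toy_img_ids, cat_ids=[1, 2]))    # [10]
print(get_img_ids(toy_cat_to_imgs, toy_img_ids, img_ids=[11, 12], cat_ids=[2]))  # [12]
```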

(4) Load annotations:

def loadAnns(self, ids=[]):
        """
        Load anns with the specified ids.
        :param ids (int array)       : integer ids specifying anns
        :return: anns (object array) : loaded ann objects
        """

(5) Load categories:

def loadCats(self, ids=[]):
        """
        Load cats with the specified ids.
        :param ids (int array)       : integer ids specifying cats
        :return: cats (object array) : loaded cat objects
        """

(6) Load images:

def loadImgs(self, ids=[]):
        """
        Load imgs with the specified ids.
        :param ids (int array)       : integer ids specifying img
        :return: imgs (object array) : loaded img objects
        """

(7) Display annotations on the image with matplotlib:

def showAnns(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None
        """

(8) Load a results file:

def loadRes(self, resFile):
        """
        Load result file and return a result api object.
        :param   resFile (str)     : file name of result file
        :return: res (obj)         : result api object
        """
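
For bounding-box detection, the results file is a JSON list with one entry per detection. The sketch below writes such a file with the standard library only. The field names follow the COCO detection results format as I understand it (`bbox` is `[x, y, width, height]` in pixels); the ids, scores, and file name are made up for illustration.

```python
import json
import os
import tempfile

# Hypothetical detections: one dict per predicted box
detections = [
    {"image_id": 10, "category_id": 1, "bbox": [12.0, 20.0, 30.0, 40.0], "score": 0.91},
    {"image_id": 10, "category_id": 2, "bbox": [5.0, 8.0, 100.0, 60.0], "score": 0.47},
    {"image_id": 11, "category_id": 1, "bbox": [0.0, 0.0, 25.0, 25.0], "score": 0.88},
]

# Write the list to a temporary file (the file name is arbitrary)
res_path = os.path.join(tempfile.mkdtemp(), "toy_bbox_results.json")
with open(res_path, "w") as f:
    json.dump(detections, f)

# Read it back to confirm that the file is a plain JSON list of dicts
with open(res_path) as f:
    reloaded = json.load(f)

print(len(reloaded), sorted(reloaded[0].keys()))
```

A file like this is what would be passed as `resFile`, provided its image and category ids exist in the ground-truth annotation file.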

(9) Download the dataset:

def download(self, tarDir = None, imgIds = [] ):
        '''
        Download COCO images from mscoco.org server.
        :param tarDir (str): COCO results directory name
               imgIds (list): images to be downloaded
        :return:
        '''

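(10) Convert an annotation (polygons or uncompressed RLE) to RLE format (in the encoded mask, 0 is background and 1 is the segmented region). Note that the method returns an RLE object, not the numpy array mentioned in the docstring below: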

def annToRLE(self, ann):
        """
        Convert annotation which can be polygons, uncompressed RLE to RLE.
        :return: binary mask (numpy 2D array)
        """

(11) Convert polygons, uncompressed RLE, or RLE to a binary mask:

def annToMask(self, ann):
        """
        Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
        :return: binary mask (numpy 2D array)
        """
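
To make the relationship between a binary mask and its RLE form concrete, here is a minimal pure-Python sketch of uncompressed RLE. It assumes the COCO convention that pixels are traversed column by column and that the counts alternate between runs of 0 and runs of 1, starting with a run of 0. The helper names `rle_encode` and `rle_decode` are hypothetical; the library itself works on numpy arrays and a compressed string form of the counts.

```python
def rle_encode(mask):
    """Encode a binary mask (list of rows) as an uncompressed RLE dict."""
    h, w = len(mask), len(mask[0])
    # Flatten column by column (column-major order)
    flat = [mask[r][c] for c in range(w) for r in range(h)]
    counts, current, run = [], 0, 0
    for value in flat:
        if value == current:
            run += 1
        else:
            counts.append(run)  # close the current run (may be 0 at the start)
            current, run = value, 1
    counts.append(run)
    return {"size": [h, w], "counts": counts}


def rle_decode(rle):
    """Decode an uncompressed RLE dict back into a list of rows."""
    h, w = rle["size"]
    flat, value = [], 0
    for n in rle["counts"]:
        flat.extend([value] * n)
        value = 1 - value  # runs alternate between 0 and 1
    return [[flat[c * h + r] for c in range(w)] for r in range(h)]


toy_mask = [
    [0, 1, 1],
    [0, 1, 0],
]
toy_rle = rle_encode(toy_mask)
print(toy_rle)                          # {'size': [2, 3], 'counts': [2, 3, 1]}
print(rle_decode(toy_rle) == toy_mask)  # True
```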

 

(2) The mask module defines four functions:

def encode(bimask):
def decode(rleObjs):
def area(rleObjs):
def toBbox(rleObjs):
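
As a rough illustration of what `area` and `toBbox` compute, the sketch below derives both values from an uncompressed RLE dict in pure Python. It is not the library implementation: `rle_area` and `rle_to_bbox` are hypothetical helpers, and the column-major layout with counts starting at a run of zeros is the same assumption as for COCO-style RLE in general. The box is returned as `[x, y, width, height]`.

```python
def rle_area(rle):
    """Number of foreground pixels: the sum of the runs of ones."""
    return sum(rle["counts"][1::2])


def rle_to_bbox(rle):
    """Tight bounding box [x, y, width, height] around the foreground pixels."""
    h, _ = rle["size"]
    xs, ys = [], []
    pos = 0
    for i, n in enumerate(rle["counts"]):
        if i % 2 == 1:  # odd positions hold the runs of ones
            for p in range(pos, pos + n):
                xs.append(p // h)  # column index in column-major order
                ys.append(p % h)   # row index
        pos += n
    if not xs:
        return [0, 0, 0, 0]  # empty mask
    x0, y0 = min(xs), min(ys)
    return [x0, y0, max(xs) - x0 + 1, max(ys) - y0 + 1]


# RLE of the 2x3 mask [[0, 1, 1], [0, 1, 0]]
example_rle = {"size": [2, 3], "counts": [2, 3, 1]}
print(rle_area(example_rle))     # 3
print(rle_to_bbox(example_rle))  # [1, 0, 2, 2]
```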

(3) The cocoeval module defines the COCOeval and Params classes:

    # The usage for CocoEval is as follows:
    #  cocoGt=..., cocoDt=...       # load dataset and results
    #  E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
    #  E.params.recThrs = ...;      # set parameters as desired
    #  E.evaluate();                # run per image evaluation
    #  E.accumulate();              # accumulate per image results
    #  E.summarize();               # display summary metrics of results
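
In Python, the steps in the comment above could look roughly like the following sketch. The wrapper `run_bbox_eval` and its two path arguments are hypothetical; the imports sit inside the function so that the snippet can be loaded even where pycocotools is not installed.

```python
def run_bbox_eval(gt_json, dt_json):
    """Evaluate the detections in dt_json against the ground truth in gt_json."""
    # Imported lazily: pycocotools is only needed when the function is called
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    coco_gt = COCO(gt_json)             # load the ground-truth annotations
    coco_dt = coco_gt.loadRes(dt_json)  # load the detection results
    evaluator = COCOeval(coco_gt, coco_dt, iouType="bbox")
    evaluator.evaluate()    # run per-image evaluation
    evaluator.accumulate()  # accumulate per-image results
    evaluator.summarize()   # print the summary metrics
    return evaluator.stats  # the summary values as an array
```

A call would pass the annotation file and a results file, for example `run_bbox_eval("instances_val.json", "my_results.json")`, where both file names are placeholders.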

2. Loading the COCO dataset with pycocotools

 1. Instantiating the COCO class

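The source of the COCO class shows that the constructor calls createIndex, which builds the following dictionaries and stores them as instance attributes:

anns: maps each annotation id to its annotation

imgs: maps each image id to its image info

imgToAnns: maps each image id to the list of annotations on that image

cats: maps each category id to its category info

catToImgs: maps each category id to the list of image ids that contain that category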

import json
import time
from collections import defaultdict

class COCO:
    def __init__(self, annotation_file=None):
        """
        Constructor of Microsoft COCO helper class for reading and visualizing annotations.
        :param annotation_file (str): location of annotation file
        :param image_folder (str): location to the folder that hosts images.
        :return:
        """
        # load dataset
        self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if not annotation_file == None:
            print('loading annotations into memory...')
            tic = time.time()
            with open(annotation_file, 'r') as f:
                dataset = json.load(f)
            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time()- tic))
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        # create index
        print('creating index...')
        anns, cats, imgs = {}, {}, {}
        imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
        if 'annotations' in self.dataset:
            for ann in self.dataset['annotations']:
                imgToAnns[ann['image_id']].append(ann)
                anns[ann['id']] = ann

        if 'images' in self.dataset:
            for img in self.dataset['images']:
                imgs[img['id']] = img

        if 'categories' in self.dataset:
            for cat in self.dataset['categories']:
                cats[cat['id']] = cat

        if 'annotations' in self.dataset and 'categories' in self.dataset:
            for ann in self.dataset['annotations']:
                catToImgs[ann['category_id']].append(ann['image_id'])

        print('index created!')

        # create class members
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats
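
To see what these dictionaries contain, the indexing logic can be run on a tiny hand-written dataset. The sketch below mirrors createIndex in a standalone function; `build_index` and `toy_dataset` are hypothetical and the data is made up.

```python
from collections import defaultdict

# A made-up dataset in the COCO annotation layout
toy_dataset = {
    "images": [{"id": 10, "file_name": "a.jpg"}, {"id": 11, "file_name": "b.jpg"}],
    "categories": [{"id": 1, "name": "cat"}, {"id": 2, "name": "dog"}],
    "annotations": [
        {"id": 100, "image_id": 10, "category_id": 1},
        {"id": 101, "image_id": 10, "category_id": 2},
        {"id": 102, "image_id": 11, "category_id": 1},
    ],
}


def build_index(dataset):
    """Build the five lookup tables in the same way as createIndex."""
    anns = {ann["id"]: ann for ann in dataset.get("annotations", [])}
    imgs = {img["id"]: img for img in dataset.get("images", [])}
    cats = {cat["id"]: cat for cat in dataset.get("categories", [])}
    img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
    for ann in dataset.get("annotations", []):
        img_to_anns[ann["image_id"]].append(ann)
        cat_to_imgs[ann["category_id"]].append(ann["image_id"])
    return anns, imgs, cats, img_to_anns, cat_to_imgs


anns, imgs, cats, img_to_anns, cat_to_imgs = build_index(toy_dataset)
print(sorted(anns))                         # [100, 101, 102]
print([a["id"] for a in img_to_anns[10]])   # [100, 101]
print(dict(cat_to_imgs))                    # {1: [10, 11], 2: [10]}
```

Note that catToImgs keeps one entry per annotation, so an image id appears several times if the image contains several objects of the same category.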

 2. Reading the data

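We use TensorFlow as the example; PyTorch is similar. In TensorFlow you subclass keras.utils.Sequence and override __getitem__ and __len__, which plays the same role as subclassing torch.utils.data.Dataset (and wrapping it in a DataLoader) in PyTorch.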
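
Before the full class, here is a minimal sketch of the batching arithmetic that its __len__ and __getitem__ rely on: the number of batches is rounded up, and indices past the end wrap around to the start so that the last batch is always full. `ToyBatches` is a hypothetical plain-Python stand-in, not a real keras Sequence.

```python
import math


class ToyBatches:
    """Plain-Python stand-in that only reproduces the index arithmetic."""

    def __init__(self, ids, batch_size):
        self.ids = ids
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch, rounded up
        return math.ceil(len(self.ids) / self.batch_size)

    def __getitem__(self, index):
        start = index * self.batch_size
        # The modulo wraps around so that the last batch is filled from the start
        return [self.ids[g % len(self.ids)]
                for g in range(start, start + self.batch_size)]


batches = ToyBatches(ids=[10, 11, 12, 13, 14], batch_size=2)
print(len(batches))                               # 3
print([batches[i] for i in range(len(batches))])  # [[10, 11], [12, 13], [14, 10]]
```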

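import math
import os.path as osp

import numpy as np
from PIL import Image
from tensorflow.keras.utils import Sequence  # assumed import path for keras.utils.Sequence

# load_image_gt, build_rpn_targets, preprocess_input and cvtColor are helper
# functions from the surrounding training project and are not shown here.

class COCODetection(Sequence):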
    def __init__(self, image_path, coco, num_classes, anchors, batch_size, config, COCO_LABEL_MAP={}, augmentation=None):
        self.image_path     = image_path

        self.coco           = coco
        self.ids            = list(self.coco.imgToAnns.keys())

        self.num_classes    = num_classes
        self.anchors        = anchors
        self.batch_size     = batch_size
        self.config         = config

        self.augmentation   = augmentation

        self.label_map      = COCO_LABEL_MAP
        self.length         = len(self.ids)

    def __getitem__(self, index):
        for i, global_index in enumerate(range(index * self.batch_size, (index + 1) * self.batch_size)):  
            global_index = global_index % self.length

            image, boxes, mask_gt, num_crowds, image_id = self.pull_item(global_index)
            #------------------------------#
            #   Get the class ids
            #------------------------------#
            class_ids   = boxes[:,  -1]
            #------------------------------#
            #   Get the box coordinates
            #------------------------------#
            boxes       = boxes[:, :-1]

            image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
            load_image_gt(image, mask_gt, boxes, class_ids, image_id, self.config, use_mini_mask=self.config.USE_MINI_MASK)

            #------------------------------#
            #   Initialize the batch arrays used for training
            #------------------------------#
            if i == 0:
                batch_image_meta    = np.zeros((self.batch_size,) + image_meta.shape, dtype=image_meta.dtype)
                batch_rpn_match     = np.zeros([self.batch_size, self.anchors.shape[0], 1], dtype=np.int32)
                batch_rpn_bbox      = np.zeros([self.batch_size, self.config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4], dtype=np.float32)
                batch_images        = np.zeros((self.batch_size,) + image.shape, dtype=np.float32)
                batch_gt_class_ids  = np.zeros((self.batch_size, self.config.MAX_GT_INSTANCES), dtype=np.int32)
                batch_gt_boxes      = np.zeros((self.batch_size, self.config.MAX_GT_INSTANCES, 4), dtype=np.int32)
                batch_gt_masks      = np.zeros((self.batch_size, gt_masks.shape[0], gt_masks.shape[1], self.config.MAX_GT_INSTANCES), dtype=gt_masks.dtype)

            if not np.any(gt_class_ids > 0):
                continue
        
            # RPN Targets
            rpn_match, rpn_bbox = build_rpn_targets(image.shape, self.anchors, gt_class_ids, gt_boxes, self.config)
        
            #-----------------------------------------------------------------------#
            #   If an image has more objects than the maximum, randomly subsample them
            #-----------------------------------------------------------------------#
            if gt_boxes.shape[0] > self.config.MAX_GT_INSTANCES:
                ids = np.random.choice(
                    np.arange(gt_boxes.shape[0]), self.config.MAX_GT_INSTANCES, replace=False)
                gt_class_ids = gt_class_ids[ids]
                gt_boxes = gt_boxes[ids]
                gt_masks = gt_masks[:, :, ids]

            #------------------------------#
            #   Copy the current sample into the batch
            #------------------------------#
            batch_image_meta[i]                             = image_meta
            batch_rpn_match[i]                              = rpn_match[:, np.newaxis]
            batch_rpn_bbox[i]                               = rpn_bbox
            batch_images[i]                                 = preprocess_input(image.astype(np.float32))
            batch_gt_class_ids[i, :gt_class_ids.shape[0]]   = gt_class_ids
            batch_gt_boxes[i, :gt_boxes.shape[0]]           = gt_boxes
            batch_gt_masks[i, :, :, :gt_masks.shape[-1]]    = gt_masks
        return [batch_images, batch_image_meta, batch_rpn_match, batch_rpn_bbox, batch_gt_class_ids, batch_gt_boxes, batch_gt_masks], \
               [np.zeros(self.batch_size), np.zeros(self.batch_size), np.zeros(self.batch_size), np.zeros(self.batch_size), np.zeros(self.batch_size)]

    def __len__(self):
        return math.ceil(len(self.ids) / float(self.batch_size))

    def pull_item(self, index):
        #------------------------------#
        #   Look up the COCO image id
        #   Load the annotations for that image id
        #------------------------------#
        image_id    = self.ids[index]
        target      = self.coco.loadAnns(self.coco.getAnnIds(imgIds = image_id))

        #------------------------------#
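        #   Split the annotations into crowd and non-crowd.
        #   crowd must be collected before target is filtered,
        #   otherwise it would always be empty.
        #------------------------------#
        crowd       = [x for x in target if ('iscrowd' in x and x['iscrowd'])]
        target      = [x for x in target if not ('iscrowd' in x and x['iscrowd'])]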
        num_crowds  = len(crowd)
        #------------------------------#
        #   Append the crowd annotations
        #   after the non-crowd ones
        #------------------------------#
        target      += crowd

        image_path  = osp.join(self.image_path, self.coco.loadImgs(image_id)[0]['file_name'])
        image       = Image.open(image_path)
        image       = cvtColor(image)
        image       = np.array(image, np.float32)
        height, width, _ = image.shape

        if len(target) > 0:
            masks = np.array([self.coco.annToMask(obj).reshape(-1) for obj in target], np.float32)
            masks = masks.reshape((-1, height, width)) 

            boxes_classes = []
            for obj in target:
                bbox        = obj['bbox']
                final_box   = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3], self.label_map[obj['category_id']]]
                boxes_classes.append(final_box)
            boxes_classes = np.array(boxes_classes, np.float32)
            boxes_classes[:, [0, 2]] /= width
            boxes_classes[:, [1, 3]] /= height

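        # Fall back to the un-augmented boxes so that `boxes` is defined even
        # when no augmentation is configured
        boxes = boxes_classes
        if self.augmentation is not None: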
            if len(boxes_classes) > 0:
                image, masks, boxes, labels = self.augmentation(image, masks, boxes_classes[:, :4], {'num_crowds': num_crowds, 'labels': boxes_classes[:, 4]})
                num_crowds  = labels['num_crowds']
                labels      = labels['labels']
                if num_crowds > 0:
                    labels[-num_crowds:] = -1
                boxes       = np.concatenate([boxes, np.expand_dims(labels, axis=1)], -1)
        
        masks               = np.transpose(masks, [1, 2, 0])
        outboxes            = np.zeros_like(boxes)
        outboxes[:, [0, 2]] = boxes[:, [1, 3]] * self.config.IMAGE_SHAPE[0]
        outboxes[:, [1, 3]] = boxes[:, [0, 2]] * self.config.IMAGE_SHAPE[1]
        outboxes[:, -1]     = boxes[:, -1]
        # np.int was removed in NumPy 1.24; use an explicit integer dtype
        outboxes            = np.array(outboxes, np.int32)
        return image, outboxes, masks, num_crowds, image_id
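
The box handling in pull_item goes through three steps: COCO's `[x, y, width, height]` is turned into corner coordinates, normalized by the image size, and finally rescaled to the network input size in `(y1, x1, y2, x2)` order. A single-box sketch of that chain, with a hypothetical helper name and made-up sizes, and without the augmentation step:

```python
def coco_box_to_model_box(bbox, img_w, img_h, out_h, out_w):
    """Convert one COCO [x, y, w, h] box to integer (y1, x1, y2, x2) at the output size."""
    x, y, w, h = bbox
    # Step 1 and 2: corner coordinates, normalized to [0, 1]
    x1, y1 = x / img_w, y / img_h
    x2, y2 = (x + w) / img_w, (y + h) / img_h
    # Step 3: rescale to the output size, reorder to (y1, x1, y2, x2), truncate to int
    return [int(y1 * out_h), int(x1 * out_w), int(y2 * out_h), int(x2 * out_w)]


# A 100x40 box at (50, 20) in a 200x100 image, mapped to a 512x512 input
print(coco_box_to_model_box([50, 20, 100, 40], img_w=200, img_h=100, out_h=512, out_w=512))
# [102, 128, 307, 384]
```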

 
