coco 数据集_手把手教你如何用SOLOV2训练自己的数据集

本文提供了一步步的指南,教你如何使用SOLOV2在自己的数据集上进行训练。首先,你需要在SOLO项目中注册自定义数据集,接着修改配置文件以适应你的类别数量和训练设置。最后,启动训练过程。教程适合那些希望入门实例分割且对SOLO系列算法感兴趣的读者。
摘要由CSDN通过智能技术生成

30ddf27061f38698038dfe89c7391598.png
title: 手把手教你如何用SOLOV2训练自己的数据集
date: 2020-07-24 14:59:11
category: 默认分类

本文介绍 手把手教你如何用SOLOV2训练自己的数据集 <!-- more -->

手把手教你如何用SOLOV2训练自己的数据集

本文由林大佬原创,转载请注明出处,来自腾讯、阿里等一线AI算法工程师组成的QQ交流群欢迎你的加入: 1037662480

最近后台很多小伙伴跟我说能不能出一些实例分割训练的教程, 因为网上很多都是关于加速/部署的, 为了满足大家的愿望, 今天特意给大家带来了现在比较火的SOLO系列算法的训练教程. 确实现在关于实力分割的教程都比较复杂, 这篇文章可以让大家轻松地入门SOTA的实例分割方案, 感兴趣的同学也可以给本文点个赞, 转发一下, 你的支持是我们创作的原始动力!

这篇教程不需要任何神力会员权限, 直接从github clone代码, 先将代码准备好, 就可以开始了:

git clone https://github.com/WXinlong/SOLO

现在网上有好几个不同版本的SOLO开源算法, 但是原作者的这个应该是比较权威的吧, 大家可以用这个版本, 笔者用下来, 这个版本具有几个优点:

  1. 它基于mmdetection, 模块化, 代码看起来也比较通熟易懂;
  2. 训练起来没什么坑, 对于没有8卡GPU的同学,用单卡或者两卡也是可以train的, 我们这篇文章会给出大家的具体指导;

但是也有一些缺点:

  1. 代码pytorch1.5跑不起来,更别说现在最新的pytorch1.7了, 需要我们修改过的代码 (兼容pytorch1.5和mmdetection2.0) 可以移步神力平台获取现成的代码;
  2. 代码注册新的dataset有点麻烦, 而且我发现(没有确认) 原始的dataset有bug, 相信很多同学在训练自己的数据集的时候会遇到第一个类别被自动忽略的bug, 当然这个bug已经被我们修复了, 详情也可以移步神力平台, 文末会放出我们的代码链接.

当然, 如果你只是训练我们今天的数据集, 那是足够了, 因为今天的数据集的主角很小很小很小, 但是麻雀虽小五脏俱全. 先来看看SOLOv2的分割效果:

bb43db6a826e1e81455925c1a3c73e8e.png

这个数据集的名字叫做 坚果数据集.

因为它很小, 所以经常被我用来检测一个算法是不是work, 基本上两分钟就可以出结果. 我也强烈建议大家用起来, 关于数据集的下载, 推荐大家看这篇文章, 这篇文章的博主其实将的比较完全了:

https://www.jianshu.com/p/a94b1629f827​www.jianshu.com

这里也贴一下下载:

wget https://github.com/Tony607/detectron2_instance_segmentation_demo/releases/download/V0.1/data.zip

数据集的版权credit@Tony607 , 感谢这位作者的工作.

材料都准备好了, 接下来按照步骤来教授大家如何训练吧.

01. SOLO注册自定义数据集

首先, 我们需要注册一个自己的自定义数据集, 在原始的SOLO项目里面, 具体的注册方式为:

a). 在 mmdet/dataset 文件下, 创建一个 coco_toy.py 的文件, 文件中就是我们要注册的数据类.

b). 给数据类添加代码:

import numpy as np
from pycocotools.coco import COCO
​
from .custom import CustomDataset
from .registry import DATASETS
​
​
@DATASETS.register_module
class CocoToyDataset(CustomDataset):
​
    CLASSES = ('date', 'fig', 'hazelnut')
​
    def load_annotations(self, ann_file):
        self.coco = COCO(ann_file)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {
    cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos
​
    def get_ann_info(self, idx):
        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return self._parse_ann_info(self.data_infos[idx], ann_info)
​
    def get_cat_ids(self, idx):
        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return [ann['category_id'] for ann in ann_info]
​
    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        for i, img_info in enumerate(self.data_infos):
            if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds
​
    def get_subset_by_classes(self):
        """Get img ids that contain any category in class_ids.
​
        Different from the coco.getImgIds(), this function returns the id if
        the img contains one of the categories rather than all.
​
        Args:
            class_ids (list[int]): list of category ids
​
        Return:
            ids (list[int]): integer list of img ids
        """
​
        ids = set()
        for i, class_id in enumerate(self.cat_ids):
            ids |= set(self.coco.cat_img_map[class_id])
        self.img_ids = list(ids)
​
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos
​
    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.
​
        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.
​
        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, seg_map. "masks" are raw annotations and not
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
​
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bb
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值