YOLOv5训练自己的数据集【较全面】

Python 同时被 2 个专栏收录
63 篇文章 0 订阅
29 篇文章 4 订阅

前言

这里主要介绍 ultralytics/yolov5 即所谓的 U版YOLO v5,关于YOLO5的原理可以自行搜索相关知识,根据笔者自己测试,效果确实不错,所以这里记录下自己的调试经验。

实验环境

单块 Nvidia GTX 2070 Super, Torch 1.5.0 / torchvision 0.6.0 / python 3.7

代码调试部分

1 源码获取

目前YOLO5社区活跃,更新很快,本人由于不想升级自己的Torch 版本,所以使的是 U版 YOLO 5的 version 2.0
版本,获取链接点击进入
【注意也要下载对应的预训练模型,在网页上的 asserts 部分】

2 自定义数据集

这个其实在官网上也有,点击进入,这里需要注意的是YOLO 5的类别要求是从 0 索引开始的,一定要注意这个,具体的目标图像和标注的目录结构如图所示:

在这里插入图片描述

其中标签的格式为 类别索引,实例中心 x y 和 实例宽 高 w h 其中 x,y,w,h均为相对于图像的比值,一定要注意。
在这里插入图片描述

2.1 自定义数据集生成标签代码

其实上面的 images中的 train 和 val 的文件夹,如果不想把自己的图像再次复制一遍占用磁盘空间,可以直接使用 软连接
ln -s /yourData Dir /images/train
这样也是可以的,就可以省去复制图像的空间了,至于 labels 文件夹下只能自己生成,这里本人提供一种从 VOC格式的 Annotations转Yolo5的labels代码:

3 YOLO 5 实验环境配置

YOLO5中 官网给的 install 步骤很简单,即直接 pip install -U -r requirements.txt ,但是这样会下载一些库的最新版本,如果您的实验环境中还有其他程序,像我只有一块GPU,也实在不想因为某些库的变化,导致其他程序跑不了了,所以,本人的步骤大致为下:
1 创建虚拟环境:
conda create -n yolo5 python=3.7

2 切换至当前的 yolo5环境
conda activate yolo5

3 如果您放心,可以直接使用 pip install -U requirements.txt
如果您不放心,那就将 requirements.txt中的库一条条的安装,比如:
pip install Cython .

由于我的很多库之前都装好了,所以我就没再安装其他的库。
到此假设您的环境配置没问题,下面就是调试代码部分了。

4 训练自己的数据集

YOLO5的代码写的其实挺不错,我们在训练自己的数据集过程中,只需要改变部分参数就行了,下面一一来讲。

4.1 yolov5-2.0\data文件夹

该文件夹目录如下,这里主要是对训练和测试数据的一个基本参数配置,包含了您数据集的 images文件夹,类的种数和具体类标 其实我们主要是新建一个yaml文件,假设命名为 mydata.yaml,格式就仿照该文件夹下任意一个yaml文件即可,推荐仿照 coco128.yaml。
在这里插入图片描述

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../YourDataDir/images/train/  # 128 images
val: ../YourDataDir/images/val/  # 128 images

# number of classes
nc: K

# class names
names: ['zhonglei_1','zhonglei_2','zhonglei_3','zhonglei_4',...,'zhonglei_K']

4.2 yolov5-2.0\models文件夹

在这里插入图片描述

这个文件夹主要存放的是 和模型相关的配置文件,一般情况下,就检测来讲,如果不是很难的数据集,YOLO5s 就会有很好的表现,这主要原因是YOLO5在训练时候运用了数据增强技术,貌似是"马赛克"(mosaic)增强,YOLO4貌似也是这个。
但是注意我们需要调整的参数主要是 类别数量 和 anchor 参数,其余参数也可以调整,不过一般建议默认吧,anchors 参数为您图像缩放为 640*640 对应对的大小,一般用KMeans进行聚类选取
在这里插入图片描述
Anchor 聚类选取的代码如下:
这个代码是 utils/utils.py 中的 一个方法,这个方法主要是假设您自己配置的 anchor 最高metric 评价值如果小于 0.99 ,该方法就会自动调用,帮您重新生成Anchors 自己可以看看琢磨下。

def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            path: path to dataset *.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.utils import *; _ = kmean_anchors()
    """
    thr = 1. / thr

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
        print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
              (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=',  ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:
        dataset = path  # dataset

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print('WARNING: Extremely small objects found. '
              '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels

    # Kmeans calculation
    from scipy.cluster.vq import kmeans
    print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unflitered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0],400)
    # ax[1].hist(wh[wh[:, 1]<100, 1],400)
    # fig.tight_layout()
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
            if verbose:
                print_results(k)

    return print_results(k)

metric 方法如下:

def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        return (best > 1. / thr).float().mean()  #  best possible recall

    bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
    print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
    if bpr < 0.99:  # threshold to recompute
        print('. Attempting to generate improved anchors, please wait...' % bpr)
        na = m.anchor_grid.numel() // 2  # number of anchors
        new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(new_anchors.reshape(-1, 2))

4.3 train.py 调整

注意看 trian.py 中的 198 行,将 80 调整为 您的种类数,这个是个分类的超参数,应该是分类损失的权重,不调整其实也可以,最好把他设置默认为 1
在这里插入图片描述

另外需要调整的就是我画横线的超参数,这些调整为自己配置的,就可以开始训练了。
在这里插入图片描述

4.4 训练模型

yolo5s.pt是从官网下载的预训练模型,如果不写 ./weights/yolo5s.pt 直接按照官网上的 --weigths yolo5s 这就会下载 yolo5s 模型,由于我们使用的是 2.0版本,直接下载可能下载不下来,所以自己最好提前下载好,放在 weights 文件夹下

python train.py --weights ./weights/yolo5s.pt

【提示】一般来讲,训练10轮以上 模型就会出现明显的收敛,如果没有出现,最好关注下自己的数据集 或者 换个模型。

【上述内容中 产生 labels 还有不完善的地方,明天补充下】

  • 3
    点赞
  • 7
    评论
  • 15
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

<p class="MsoNormal"><span style="font-family: '微软雅黑',sans-serif;">YOLO</span><span style="font-family: '微软雅黑',sans-serif;">系列是基于深度学习的端到端实时目标检测方法。 PyTorch版的YOLOv5轻量而性能高,更加灵活和易用,当前非常流行。</span><span style="font-family: 微软雅黑, sans-serif;"> </span></p> <p class="MsoNormal"><span style="font-family: '微软雅黑',sans-serif;">本课程将手把手地教大家使用labelImg标注和使用YOLOv5训练自己数据集。课程实战分为两个项目:单目标检测(足球目标检测)和多目标检测(足球和梅西同时检测)。</span><span style="font-family: 微软雅黑, sans-serif;"> </span></p> <p class="MsoNormal"><span style="font-family: '微软雅黑',sans-serif;">本课程的YOLOv5使用ultralytics/yolov5,在<span style="color: #e03e2d;"><strong>Windows</strong></span>系统上做项目演示。包括:安装YOLOv5、标注自己数据集、准备自己数据集、修改配置文件、使用wandb训练可视化工具、训练自己数据集、测试训练出的网络模型和性能统计。</span><span style="font-family: 微软雅黑, sans-serif;"> </span></p> <p class="MsoNormal"><span style="font-family: 微软雅黑, sans-serif;">希望学习Ubuntu上演示的同学,请前往 </span><span style="font-family: 微软雅黑, sans-serif;">《</span><span style="font-family: 微软雅黑, sans-serif;">YOLOv5(PyTorch)</span><span style="font-family: 微软雅黑, sans-serif;">实战:训练自己数据集(Ubuntu)》课程链接:https://edu.csdn.net/course/detail/30793</span><span style="font-family: 宋体;"><span style="font-size: 14px;"> </span></span></p> <p style="margin-left: 0cm;"> </p> <p style="margin-left: 0cm;">本人推出了有关YOLOv5目标检测的系列课程。请持续关注该系列的其它视频课程,包括:</p> <p>《YOLOv5(PyTorch)目标检测实战:训练自己数据集》</p> <p>Ubuntu系统 <strong><a href="https://edu.csdn.net/course/detail/30793"><span style="color: #7c79e5;">https://edu.csdn.net/course/detail/30793</span></a></strong></p> <p>Windows系统 <strong><a href="https://edu.csdn.net/course/detail/30923"><span style="color: #7c79e5;">https://edu.csdn.net/course/detail/30923</span></a></strong></p> <p class="MsoNormal" style="background: white; word-break: break-all;" align="left"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5(PyTorch)</span>目标检测:原理与源码解析》课程链接:<strong><span lang="EN-US"><a href="https://edu.csdn.net/course/detail/31428"><span style="color: windowtext; font-weight: normal; text-decoration-line: none;">https://edu.csdn.net/course/detail/31428</span></a></span></strong></span></p> <p class="MsoNormal" style="background: white; word-break: break-all;" align="left"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5</span>目标检测实战:<span lang="EN-US">Flask Web</span>部署》课程链接:<strong><span lang="EN-US"><a href="https://edu.csdn.net/course/detail/31087"><span style="color: windowtext; font-weight: normal; text-decoration-line: none;">https://edu.csdn.net/course/detail/31087</span></a></span></strong></span></p> <p class="MsoNormal" style="background: white; word-break: break-all;" align="left"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5(PyTorch)</span>目标检测实战:<span lang="EN-US">TensorRT</span>加速部署》课程链接<strong>:<span lang="EN-US"><a href="https://edu.csdn.net/course/detail/32303"><span style="color: windowtext; font-weight: normal; text-decoration-line: none;">https://edu.csdn.net/course/detail/32303</span></a></span></strong></span></p> <p class="MsoNormal" style="background: white; word-break: break-all;" align="left"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5</span>目标检测实战:<span lang="EN-US">Jetson Nano</span>部署》课程链接:<span lang="EN-US"><a href="https://edu.csdn.net/course/detail/32451" data-cke-saved-href="https://edu.csdn.net/course/detail/32451"><span style="color: windowtext; text-decoration-line: none;">https://edu.csdn.net/course/detail/32451</span></a></span></span></p> <p class="MsoNormal" style="background: white; word-break: break-all;" align="left"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5+DeepSORT</span>多目标跟踪与计数精讲》课程链接:<strong><span lang="EN-US"><a href="https://edu.csdn.net/course/detail/32669"><span style="color: windowtext; font-weight: normal; text-decoration-line: none;">https://edu.csdn.net/course/detail/32669</span></a></span></strong></span></p> <p class="MsoNormal" style="background: white; word-break: break-all;" align="left"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5</span>实战口罩佩戴检测》课程链接:<strong><span lang="EN-US"><a href="https://edu.csdn.net/course/detail/32744"><span style="color: windowtext; font-weight: normal; text-decoration-line: none;">https://edu.csdn.net/course/detail/32744</span></a></span></strong></span></p> <p class="MsoNormal"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5</span>实战中国交通标志识别》课程链接:<span lang="EN-US"><a href="https://edu.csdn.net/course/detail/32744"><span style="color: windowtext; text-decoration-line: none;">https://edu.csdn.net/course/detail/35209</span></a></span></span></p> <p class="MsoNormal"><span style="font-family: '微软雅黑',sans-serif;">《<span lang="EN-US">YOLOv5</span>实战垃圾分类目标检测》课程链接:<span lang="EN-US"><a href="https://edu.csdn.net/course/detail/35284"><span style="color: windowtext; text-decoration-line: none;">https://edu.csdn.net/course/detail/35284</span></a></span></span></p> <p class="MsoNormal"><span style="font-family: '微软雅黑',sans-serif;"><span lang="EN-US"><span style="color: windowtext; text-decoration-line: none;"><img src="https://img-bss.csdnimg.cn/202106220205279501.jpg" alt="课程内容" /></span></span></span></p> <p class="MsoNormal"><span lang="EN-US" style="font-family: '微软雅黑',sans-serif;"> </span></p> <p> </p> <p class="MsoNormal"><span lang="EN-US"> <img src="https://img-bss.csdnimg.cn/202106220204372263.jpg" alt="图片检测效果" /></span></p> <p> </p> <p> </p> <p> </p> <p> </p>
©️2021 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值