一、背景
最近要参加一个比赛,比赛是做图像识别的,比赛提供了部分BDD100K数据集,每个图片是1280*720的交通图片,每个标签中带有物体类别、物体左上锚点、物体右下锚点,如下:
比赛提供的标签数据有9种目标,如下:
比赛前的培训课程上,官方演示了使用YOLO V5模型训练了部分数据,因此我准备使用更高版本的YOLO V10模型训练完整的BDD100K数据集
我的环境信息如下:
- 操作系统windows11
- 显卡为NVIDIA GeForce RTX 3080
- cuda版本11.7
- cudnn版本8.9.7.29
- torch版本2.0.1+cu117
- torchvision版本0.15.0+cu117
二、准备数据集与环境
1. 下载BDD100K数据集
我忘记我的数据集是从哪里下载的了,可以参照BDD100K官方文档去下载,不过我下载的数据集和官网上的路径不大一致,但是差别不大,可以用
我下载下来的数据集路径如下:
images是图片文件路径,里面分为10k和100k,10k的数据体量小,我这里要训练100k的数据,test、train、val分别是测试集、训练集和验证集,里面都是jpg格式的图片数据
labels是标签文件路径,我下载的这里只有100k的,是不是10k的标签在100k里都能找到我没有去验证,对应这images/100k里的路径,labels/100l里有训练集和验证集的标签文件,都是json格式的,文件名与图片一一对应,如下:
标签json文件格式如下:
需要注意的是,BDD100K的数据标签里objects列表后面有一些内容长这样:
我没有去深挖这个标签代表什么,有兴趣可以去研究,因为我在这个项目中用不到这个标签,因此后面做标签转换的时候略过了这些标签。
2. 下载YOLO V10代码
从YOLO V10的Github主页上下载代码,有git的用git,没有git直接下载zip即可
下载完成后,把YOLO v10的代码放到项目路径中,我的路径如图:
3. 标签转换
BDD100K的数据标注方式与YOLO v10的标注方式不同,BDD100K是用左上和右下两个锚点的标注方式,而YOLO V10则是是用的中心点标注方法,例如:
在上图这个例子里,图片中有三个物体,两个人和一条领带,yolo模型的标注文件是txt文件,一行一条标注信息
我们以图片上的齐达内为例,他的标注有5个信息,分别是类别、中心点x相对位置,中心点y相对位置,图像高百分比height,图像宽百分比height
yolo模型不管图片的像素尺寸是多少,图片的长和宽都是0到1的范围,里面的每个物体也是一个矩形,x=0.48表示齐达内的中心点在图像x轴上的位置为0.48,y=0.63表示齐达内的中心点在图像y轴的0.63,height=0.71表示齐达内的高占整个图片高的71%,width=0.69表示齐达内的宽占整个图片宽的69%,应该还是很好理解的。
这种标注相比于BDD100K的标注,好处在于更加通用,如果BDD100K的图像被缩放,不再是1280*720的尺寸,那么所有的标注锚点坐标都需要对应修改
所以我们需要先编写一个python脚本,来吧json格式的BDD100K标注文件转化成txt格式的YOLO v10标注文件,基本思路如下
假设我们有一张图片,大小是1280*720,在图片内部有一个标签为car的物体,左上角锚点的坐标是(100, 120),右下角的锚点坐标是(200, 320),那么:
- 中心点坐标
X = ((200 - 100) / 2 + 100) / 1280 = 150 / 1280 = 0.1171875
- 中心点坐标
Y = ((320 - 120) / 2 + 120) / 720 = 440 / 720 = 0.3055556
- 物体宽度
width = (200 - 100) / 1280 = 0.078125
- 物体高度
height = (320 - 120) / 720 = 0.277778
简化一下,得到公式:
- center_X = (x1 + x2) / 2 / 1280
- center_Y = (y1 + y2) / 2 / 720
- width = (x2 - x1) / 1280
- height = (y2 - y1) / 1280
根据上面的公示编写脚本就可以了,需要注意的是要把训练集的标签和验证集的标签都进行转换,我的转换脚本名为转换标签.py
,路径在cs目录中:
代码如下:
import json
import os
# 输入和输出路径
# input_json_path = '../bdd100k/labels/100k/train'
# output_txt_dir = '../bdd100k/labels/100k/train'
input_json_path = '../bdd100k/labels/100k/val'
output_txt_dir = '../bdd100k/labels/100k/val'
# 类别映射,可以根据需要修改
category_map = {
'bus': 0,
'traffic light': 1,
'traffic sign': 2,
'person': 3,
'bike': 4,
'truck': 5,
'motor': 6,
'car': 7,
'rider': 8,
# 添加其他类别
}
category_keys_list = list(category_map.keys()) # 所有需要识别的物体名称列表
# 确保输出目录存在
os