在安全帽佩戴检测数据集训练YOLOv5--数据集处理
参考链接
SafetyHelmetWearing-Dataset(安全帽佩戴检测数据集)
Train Custom Data(YOLOv5 训练自定义数据集)
数据集
因为我是直接在Google Colab上训练的,所以直接打开第二个链接添加到云端硬盘,解压即可;
如果是无法访问Google,那就百度盘下载吧!
挂载谷歌云端硬盘:
from google.colab import drive
drive.mount("/gdrive")
%cd "/gdrive/My Drive/YOLOv5"
数据集解包:
!unzip -q VOC2028.zip -d ./
这是我的大致文件结构:
YOLOv5/ # 我自己新建的主目录
yolov5/ # 官方git 克隆下来的
VOC2028/ # 这是数据集解压后的文件夹
ImageSets/Main/ # train、test、val的分割txt
Annotations/ # xml标注文件
JPEGImages/ # 图片
SHWD/ # 这是处理后的文件夹,用来训练
images/ # 图片
train/
val/
test/
labels/ # 用于yolo的txt标签
train/
val/
test/
解释一下,原来的数据集是xml格式的标注数据,不能直接用于yolo训练,需要转为txt。
train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
yolov5中支持这3种格式的数据集,我这里使用的是第一种。
处理数据集
将原来的数据集VOC2028/
转为需要的格式SHWD/
:
- 我喜欢先将目录结构建好:
import os
train_img_dir = "SHWD/images/train"
val_img_dir = "SHWD/images/val"
test_img_dir = "SHWD/images/test"
train_label_dir = "SHWD/labels/train"
val_label_dir = "SHWD/labels/val"
test_label_dir = "SHWD/labels/test"
os.makedirs(train_img_dir)
os.makedirs(val_img_dir)
os.makedirs(test_img_dir)
os.makedirs(train_label_dir)
os.makedirs(val_label_dir)
os.makedirs(test_label_dir)
- 接着转换格式并将图片和标签放到对应文件:
import os, sys
import shutil
from tqdm import tqdm
import xml.etree.ElementTree as ET
classes = ["hat", "person"]
sets = ["train", "val", "test"]
def convert(size, box):
dw = 1. / size[0]
dh = 1. / size[1]
x = (box[0] + box[1]) / 2.0
y = (box[2] + box[3]) / 2.0
w = box[1] - box[0]
h = box[3] - box[2]
x = x * dw # x_center
y = y * dh # y_center
w = w * dw # width
h = h * dh # height
return (x, y, w, h)
def parse_xml(xml_path, dst_label_path):
anno_xml = xml_path
anno_txt = dst_label_path
if os.path.exists(anno_xml):
xml = open(anno_xml, "r")
txt = open(anno_txt, "w")
tree = ET.parse(xml)
root = tree.getroot()
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text)
for obj in root.iter('object'):
cls = obj.find('name').text
difficult = obj.find('difficult').text
if cls not in classes or difficult == 1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
bbox = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
float(xmlbox.find('ymax').text))
yolo_bbox = convert((w, h), bbox)
yolo_anno = str(cls_id) + " " + " ".join([str(i) for i in yolo_bbox]) + '\n'
txt.write(yolo_anno)
xml.close()
txt.close()
else:
print(anno_xml, "文件不存在")
def copy_to(src, dst):
shutil.copyfile(src, dst)
for s in sets:
name_path = "VOC2028/ImageSets/Main/{}.txt".format(s)
f = open(name_path, "r")
names = f.readlines()
f.close()
for name in tqdm(names):
name = name.replace('\n', '').replace('\r', '')
image_path = r"VOC2028/JPEGImages/{}.jpg".format(name)
xml_path = r"VOC2028/Annotations/{}.xml".format(name)
dst_image_path = r"SHWD/images/{}/{}.jpg".format(s, name)
dst_label_path = r"SHWD/labels/{}/{}.txt".format(s, name)
if os.path.exists(image_path) and os.path.exists(xml_path):
parse_xml(xml_path, dst_label_path)
if not os.path.exists(dst_image_path):
copy_to(image_path, dst_image_path)
else:
print(dst_image_path, "文件已存在")
else:
print(image_path, xml_path)
嗯,是的!代码没有注释
标签中的数据除了第一个值(类别)是整数外,其他四个值均为浮点数
txt格式:
每张图片对应一个txt
每个目标一行,整个图片没有目标的话不需要有txt文件
每行的格式为class_num x_center y_center width height
其中class_num取值为0至total_class - 1,
框的四个值x_center y_center width height是相对于图片分辨率大小正则化的0-1之间的数,左上角为(0,0),右下角为(1,1)
一图胜千言,看图说话
最终,数据集的结果就酱啊:
SHWD
images
train
***.jpg
val
***.jpg
test
***.jpg
labels
train
***.txt
val
***.txt
test
***.txt
数据相关可视化
这是在yolo中的代码,在训练时,会通过seaborn的pairplot将数据集的一些信息画出来。
def plot_labels(labels, save_dir=Path(''), loggers=None):
# plot dataset labels
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
colors = color_list()
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
# seaborn correlogram
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
4个柱状图是分别对[‘x’, ‘y’, ‘width’, ‘height’]的统计
scanning images
这部分是另外一个数据集的,代码和文件结构不太一样,可以参考
发现scanning images很慢可能的原因是数据命名格式化问题。之前数据名字是乱的,没有规律的,scanning images时就很慢,基本上是1.7it/s,现在重新命名之后,一下就完成了,能达到几十到一百多。
images_path = "./road-damage/images"
labels_path = "./road-damage/labels"
images = os.listdir(images_path)
for idx, image in enumerate(images, 1):
name, ext = os.path.splitext(image)
if ext in [".jpg", ".jpeg", ".png"]:
new_name = str(idx).zfill(6) # 000000 ---999999这种形式
src_image = os.path.join(images_path, image)
dst_image = os.path.join(images_path, new_name + ".jpg")
src_label = os.path.join(labels_path, name + ".txt")
dst_label = os.path.join(labels_path, new_name + ".txt")
if os.path.exists(src_image) and os.path.exists(src_label):
print(src_image, dst_image, src_label, dst_label)
os.rename(src_image, dst_image)
os.rename(src_label, dst_label)