NightOwls和WiderPerson数据集转换为YOLO格式
一、NightOwls数据集
1、NightOwls数据集介绍
论文链接:https://www.robots.ox.ac.uk/~vgg/publications/2018/Neumann18b/neumann18b.pdf。
下载地址:https://www.nightowls-dataset.org/。
NightOwls数据集中的图片如下:
包含4个类别:行人、自行车驾驶员、摩托车驾驶员、忽略区域(Pedestrians、Bicycledriver、Motorbikedriver、Ignore areas)
数据来源:
采集自3个国家的多个欧洲城市
4个季节
黎明和夜晚
所有天气条件
大量动态对象
变化的场景布局
变化的背景
图像质量真实,随场景照明和车速而变化,既有模糊的图像也有清晰的图像
图像分辨率1024 x 640
在16fps的原始帧率下每三帧取一帧,因此等效帧率约为5.33fps
2、json格式展示
本项目只下载了验证集,因为训练集有100多GB,下载起来比较慢。
下载的数据标注格式为json格式,以下列举了部分内容:
{
"images":[{
"height": 640,
"width": 1024,
"daytime": "night",
"file_name": "58c5832dbc2601370015a128.png",
"id": 7025712,
"recordings_id": 36,
"timestamp": 6159653184
},
{
"height": 640,
"width": 1024,
"daytime": "night",
"file_name": "58c5832dbc2601370015a129.png",
"id": 7025713,
"recordings_id": 36,
"timestamp": 6159713463
}
],
"annotations":[{
"occluded": null,
"difficult": false,
"bbox": [
707,
226,
45,
110
],
"id": 7025712,
"category_id": 1,
"image_id": 7057530,
"pose_id": 3,
"tracking_id": 7001819,
"ignore": 0,
"area": 4950,
"truncated": false
},
{
"occluded": null,
"difficult": null,
"bbox": [
755,
241,
264,
142
],
"id": 7025713,
"category_id": 4,
"image_id": 7057530,
"pose_id": 5,
"tracking_id": 7001820,
"ignore": 1,
"area": 37488,
"truncated": false
}
],
"categories": [
{
"name": "pedestrian",
"id": 1
},
{
"name": "bicycledriver",
"id": 2
},
{
"name": "motorbikedriver",
"id": 3
},
{
"name": "ignore",
"id": 4
}
],
"poses": [
{
"name": "front",
"id": 0
},
{
"name": "left",
"id": 1
},
{
"name": "back",
"id": 2
},
{
"name": "right",
"id": 3
},
{
"name": "nan",
"id": 4
}
]
}
3、标签格式转换(json格式转txt格式)
通过以下代码将json格式转换成yolo的txt格式:
import json
import os
from tqdm import tqdm
def yolo_trans(result_save_path, json_data, class_num):
    """Write YOLO-format label files for the NightOwls annotations.

    For every image that keeps at least one annotation, a ``<stem>.txt`` file
    is written into *result_save_path* with one line per box:
    ``class_id x_center y_center width height``, normalised by the image size.
    Images without kept annotations get no label file.

    Args:
        result_save_path: Existing directory that receives the label files.
        json_data: Parsed NightOwls JSON with ``images`` and ``annotations``
            lists.
        class_num: ``3`` keeps pedestrian / bicycledriver / motorbikedriver
            as classes 0 / 1 / 2; ``1`` merges them into the single class 0.

    Raises:
        ValueError: If *class_num* is neither 1 nor 3.
    """
    if class_num not in (1, 3):
        # An unsupported value used to surface later as a NameError on
        # ``class_id``; fail fast with a clear message instead.
        raise ValueError(f"class_num must be 1 or 3, got {class_num!r}")

    # Group the usable annotations by the id of the image they belong to.
    annotations_dict = {}
    for annotation in json_data['annotations']:
        # Drop boxes flagged ``ignore`` and anything outside categories 1-3.
        # Category 4 is the "ignore" area class and must never be mapped to
        # a person class.
        if annotation['ignore'] != 0 or annotation['category_id'] not in (1, 2, 3):
            continue
        annotations_dict.setdefault(annotation['image_id'], []).append(annotation)

    for image in tqdm(json_data['images']):
        image_annotations = annotations_dict.get(image['id'])
        if not image_annotations:
            continue
        img_w = image['width']
        img_h = image['height']
        filename = os.path.join(result_save_path, image['file_name'].split('.')[0] + '.txt')
        with open(filename, 'w') as f:
            for annotation in image_annotations:
                # bbox is [x_min, y_min, width, height] in pixels.
                x_min, y_min, box_w, box_h = annotation['bbox']
                # Clip to the image so all normalised values stay in [0, 1].
                # Ultralytics' label check, for example, treats out-of-range
                # coordinates as a corrupt label file. Boxes that end up
                # empty are dropped.
                x1 = max(0, x_min)
                y1 = max(0, y_min)
                x2 = min(img_w, x_min + box_w)
                y2 = min(img_h, y_min + box_h)
                if x2 <= x1 or y2 <= y1:
                    continue
                x_center = (x1 + x2) / 2 / img_w
                y_center = (y1 + y2) / 2 / img_h
                width = (x2 - x1) / img_w
                height = (y2 - y1) / img_h
                # class_num == 3 maps category ids 1..3 to YOLO ids 0..2;
                # class_num == 1 merges everything into class 0.
                class_id = annotation['category_id'] - 1 if class_num == 3 else 0
                f.write(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
    print("转换完成")
def main():
    """Convert the NightOwls validation annotations into YOLO label files."""
    # Raw strings: these Windows paths contain backslash sequences such as
    # "\c", "\d", "\V" and "\y", which are invalid escapes in normal string
    # literals (SyntaxWarning since Python 3.12).
    json_data_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\nightowls_validation.json'
    result_save_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\yolo_label'
    os.makedirs(result_save_path, exist_ok=True)
    # 3: keep pedestrian / bicycledriver / motorbikedriver as separate classes;
    # 1: merge them into a single "person" class.
    class_num = 3
    # JSON is UTF-8; do not rely on the platform's default encoding.
    with open(json_data_path, 'r', encoding='utf-8') as file:
        json_data = json.load(file)
    yolo_trans(result_save_path, json_data, class_num)


if __name__ == "__main__":
    main()
通过以上代码可以将标注分为3个类别:Pedestrians、Bicycledriver、Motorbikedriver,其中Ignore areas是我们不需要的类别。
如果想将Pedestrians、Bicycledriver和Motorbikedriver合并为一个类别person,将以上代码的class_num = 3改为class_num = 1即可。在本项目中使用的是class_num = 1。
4、yolo标签展示
通过以下代码查看数据是否标注正确:
# 通过标签可视化图片以及标签
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
# Class id -> display name (labels generated with class_num = 1 only use id 0).
class_names = {
    0: 'person'
}
# Class id -> colour passed to OpenCV drawing calls (BGR order).
class_colors = {
    0: (0, 0, 255),  # red
}
# Raw strings: the Windows paths contain "\c", "\d", "\V", "\y", ... which are
# invalid escape sequences in normal string literals.
imgs_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\nightowls_validation'
labels_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\yolo_label'
save_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\yolo_person_vis'
os.makedirs(save_path, exist_ok=True)
files_list = os.listdir(labels_path)
for file_name in tqdm(files_list):
    label_path = os.path.join(labels_path, file_name)
    img_name = file_name.split('.')[0] + '.png'
    # Each line is "class x_center y_center width height", normalised to [0, 1].
    with open(label_path, 'r') as f:
        labels = f.read().strip().split('\n')
    if labels[0] == '':
        # Empty label file (negative sample): nothing to draw.
        continue
    # Close the file handle right away and force 3 channels so that
    # grayscale PNGs do not break the RGB -> BGR conversion.
    with Image.open(os.path.join(imgs_path, img_name)) as image_pil:
        image_np = np.array(image_pil.convert('RGB'))
    image = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    img_h, img_w = image.shape[:2]
    thickness = 2  # line thickness for boxes and text
    font = cv2.FONT_HERSHEY_SIMPLEX
    for label in labels:
        parts = label.split()
        class_id = int(parts[0])
        x_center, y_center, width, height = map(float, parts[1:])
        # Normalised centre/size box -> pixel corner coordinates.
        x1 = int((x_center - width / 2) * img_w)
        y1 = int((y_center - height / 2) * img_h)
        x2 = int((x_center + width / 2) * img_w)
        y2 = int((y_center + height / 2) * img_h)
        # Fall back to a default colour / the numeric id for classes missing
        # from the tables above (e.g. labels generated with class_num = 3).
        color = class_colors.get(class_id, (0, 255, 0))
        text = class_names.get(class_id, str(class_id))
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
        # Class name drawn just above the box.
        cv2.putText(image, text, (x1, y1 - 3), font, 0.7, color, thickness, cv2.LINE_AA)
    cv2.imwrite(os.path.join(save_path, img_name), image)
运行完成后,随机查看几张图片,可以发现标注是正确的:
5、查看标签个数
通过以下代码查看标签的个数:
import os
from tqdm import tqdm
def count_labels_in_txt_folder(folder_path):
    """Count YOLO boxes per class id over every ``.txt`` file in a folder.

    Args:
        folder_path: Directory holding YOLO label files.

    Returns:
        dict mapping the integer class id (first token of each non-blank
        line) to its number of occurrences, in order of first appearance.
    """
    counts = {}
    label_files = [name for name in os.listdir(folder_path) if name.endswith('.txt')]
    for name in tqdm(label_files):
        with open(os.path.join(folder_path, name), 'r') as handle:
            for row in handle:
                tokens = row.split()
                if not tokens:
                    # Blank line: no box on it.
                    continue
                cls = int(tokens[0])
                counts[cls] = counts.get(cls, 0) + 1
    return counts
# Folder with the generated NightOwls YOLO labels. Raw string, because the
# path contains backslash sequences such as "\V" and "\y" that are invalid
# escapes in a normal string literal.
folder_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\yolo_label'
label_counts = count_labels_in_txt_folder(folder_path)
# Per-class box counts.
for class_id, count in label_counts.items():
    print(f'Class {class_id}: {count} instances')
# Overall number of boxes across all classes.
print(f'Total: {sum(label_counts.values())} instances')
代码运行完成后,可以发现一共有9702个标签。
6、图片和标签匹配
可以发现转换后的标注文件只有6595个,而图片数据有51848张,此时标注文件和图片数量是不匹配的。
如果使用全部的图片数据,负样本就太多了,因此我们除了选择6595张正样本,又在剩余的图片中随机抽取了5000张作为负样本,并生成这些负样本的空txt标签文件,除此之外,还将图片的格式从png转换为jpg。实现代码如下:
import os
import shutil
from tqdm import tqdm
import random
def _save_as_jpg(src_path, dst_path):
    """Re-encode the image at *src_path* as a JPEG file at *dst_path*."""
    # PIL is already used by the visualisation scripts of this article; it is
    # imported here because this script's own imports do not include it.
    from PIL import Image

    with Image.open(src_path) as img:
        # JPEG supports neither alpha nor palettes, so normalise to RGB.
        # quality=95 keeps re-encoding artefacts low (Pillow's default is 75).
        img.convert('RGB').save(dst_path, 'JPEG', quality=95)


def move_img(orig_imgs_path, labels_path, dest_imgs_path):
    """Convert the image of every label file to JPEG in *dest_imgs_path*.

    For each ``<stem>.txt`` in *labels_path*, ``<stem>.png`` is read from
    *orig_imgs_path* and saved as ``<stem>.jpg`` in *dest_imgs_path*.
    A plain file copy would only rename the extension and leave PNG data
    inside the ``.jpg`` file.
    """
    for file_name in tqdm(os.listdir(labels_path)):
        if not file_name.endswith('.txt'):
            continue
        stem = file_name.split('.')[0]
        _save_as_jpg(os.path.join(orig_imgs_path, stem + '.png'),
                     os.path.join(dest_imgs_path, stem + '.jpg'))


def select_img(orig_imgs_path, labels_path, dest_imgs_path, select_num, seed=None):
    """Add *select_num* random unlabeled images as negatives with empty labels.

    Candidates are the files in *orig_imgs_path* whose stem has neither a
    label file in *labels_path* nor an image in *dest_imgs_path*. Each chosen
    image is saved as ``<stem>.jpg`` in *dest_imgs_path*, and an empty
    ``<stem>.txt`` is created in *labels_path*.

    Args:
        orig_imgs_path: Folder with all original images.
        labels_path: Folder with the YOLO label files.
        dest_imgs_path: Folder that collects the selected ``.jpg`` images.
        select_num: Number of negatives to add (fewer if not enough remain).
        seed: Optional seed for a reproducible draw; ``None`` uses the global
            ``random`` state.
    """
    # Compare stems, not full names. The destination holds "<stem>.jpg" while
    # the source holds "<stem>.png", so a name-based set difference excluded
    # nothing. Labelled images could then be drawn as "negatives" and have
    # their label files overwritten with empty ones.
    taken = {name.split('.')[0] for name in os.listdir(dest_imgs_path)}
    taken |= {name.split('.')[0] for name in os.listdir(labels_path)}
    # sorted() gives a stable starting order (set iteration order is not).
    candidates = sorted(name for name in os.listdir(orig_imgs_path)
                        if name.split('.')[0] not in taken)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(candidates)

    for file_name in tqdm(candidates[:select_num]):
        stem = file_name.split('.')[0]
        _save_as_jpg(os.path.join(orig_imgs_path, file_name),
                     os.path.join(dest_imgs_path, stem + '.jpg'))
        # An empty label file marks the image as a negative sample.
        open(os.path.join(labels_path, stem + '.txt'), 'w').close()


def main():
    """Collect all labelled NightOwls images plus 5000 random negatives as JPEG."""
    # Raw strings: "\c", "\d", "\V", "\y", ... are invalid escapes otherwise.
    orig_imgs_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\nightowls_validation'
    labels_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\yolo_label'
    dest_imgs_path = r'D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\train_image'
    os.makedirs(dest_imgs_path, exist_ok=True)
    # Images that have a label file (positives).
    move_img(orig_imgs_path, labels_path, dest_imgs_path)
    # 5000 random remaining images as negatives, each with an empty label.
    select_num = 5000
    select_img(orig_imgs_path, labels_path, dest_imgs_path, select_num)


if __name__ == '__main__':
    main()
通过以上操作,最终得到了两个文件夹train_image和yolo_label,一共11595条数据。
二、WiderPerson数据集
1、WiderPerson数据集介绍
下载地址:http://www.cbsr.ia.ac.cn/users/sfzhang/WiderPerson/。
WiderPerson数据集是一个行人检测数据集,其中的图像是从广泛的场景中选择的,不只局限于交通场景。该数据集一共13382张图像,对其中9000张图像进行了标注。
2、标签格式转换
该数据集每张图片的标注如下:
< 此图像中的注释数量 = N >
< anno 1 >
< anno 2 >
...
< anno N >
其中每行表示一个对象实例,格式为 [class_label, x1, y1, x2, y2],x1、y1表示左上角坐标,x2、y2表示右下角坐标。
类标签定义为:
< class_label =1:行人 >
< class_label =2:骑手 >
< class_label =3:部分可见人员 >
< class_label =4:忽略区域 >
< class_label =5:人群 >
通过以下代码可视化部分图片的标注:
# 通过标签可视化图片以及标签
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
# Raw WiderPerson class id -> label text:
# 1 pedestrian, 2 rider, 3 partially-visible person, 4 ignore region, 5 crowd.
class_names = {
    1: '1',
    2: '2',
    3: '3',
    4: '4',
    5: '5'
}
# Class id -> colour passed to OpenCV drawing calls (BGR order).
class_colors = {
    1: (0, 0, 255),
    2: (0, 255, 0),
    3: (255, 0, 0),
    4: (255, 0, 255),
    5: (0, 255, 255),
}
# Raw strings: "\c", "\d", "\W", "\i", "\l", ... are invalid escapes otherwise.
imgs_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\tmp\images'
labels_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\tmp\labels'
save_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\tmp\box_vis'
os.makedirs(save_path, exist_ok=True)
files_list = os.listdir(labels_path)
for file_name in tqdm(files_list):
    label_path = os.path.join(labels_path, file_name)
    # Text before the first dot is the image id, so both "<id>.txt" and
    # "<id>.jpg.txt" names map to "<id>.jpg".
    img_name = file_name.split('.')[0] + '.jpg'
    # Raw annotation format: the first line is the box count N, followed by
    # N lines of "class_label x1 y1 x2 y2" in pixels.
    with open(label_path, 'r') as f:
        labels = f.read().strip().split('\n')
    if labels[0] == '':
        continue
    # Close the file handle right away and force 3 channels so that
    # grayscale JPEGs do not break the RGB -> BGR conversion.
    with Image.open(os.path.join(imgs_path, img_name)) as image_pil:
        image_np = np.array(image_pil.convert('RGB'))
    image = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    thickness = 1  # line thickness for boxes and text
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Skip the count line.
    for label in labels[1:]:
        parts = label.split()
        if len(parts) != 5:
            # Blank or malformed line.
            continue
        class_id = int(parts[0])
        # int(float(...)) accepts both "12" and "12.0" style coordinates.
        x1, y1, x2, y2 = (int(float(value)) for value in parts[1:])
        color = class_colors[class_id]
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
        # Class id drawn just above the box.
        cv2.putText(image, class_names[class_id], (x1, y1 - 3), font, 0.7, color,
                    thickness, cv2.LINE_AA)
    cv2.imwrite(os.path.join(save_path, img_name), image)
随机打开两张图片进行查看:
可以发现图片中的人非常多。由于本项目主要做人员检测,所以去掉第4类(忽略区域)和第5类(人群)的标注,将1、2、3类归并为同一类,并将标注文件转换为yolo格式(每行为<class_id> <x_center> <y_center> <width> <height>,坐标和宽高均按图像的宽和高归一化到0~1),实现代码如下:
# 将标签转换为yolo格式
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
# Raw strings: "\c", "\d", "\W", "\I", "\A", "\y", ... are invalid escapes otherwise.
imgs_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\Images'
labels_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\Annotations'
save_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\yolo_label'
os.makedirs(save_path, exist_ok=True)
# WiderPerson classes merged into YOLO class 0 ("person"): 1 pedestrian,
# 2 rider, 3 partially-visible person. 4 (ignore region) and 5 (crowd) are dropped.
keep_classes = {1, 2, 3}
files_list = os.listdir(labels_path)
for file_name in tqdm(files_list):
    label_path = os.path.join(labels_path, file_name)
    # Text before the first dot is the image id.
    img_name = file_name.split('.')[0] + '.jpg'
    # Raw annotation format: the first line is the box count N, followed by
    # N lines of "class_label x1 y1 x2 y2" in pixels.
    with open(label_path, 'r') as f:
        lines = f.read().strip().split('\n')
    if lines[0] == '':
        continue
    # Only the image size is needed. Image.open reads it from the file header
    # without decoding the pixels, and the with-block closes the file.
    with Image.open(os.path.join(imgs_path, img_name)) as img:
        img_w, img_h = img.size
    yolo_annotations = []
    # Skip the count line.
    for line in lines[1:]:
        data = line.split()
        if len(data) != 5 or int(data[0]) not in keep_classes:
            continue
        x_min, y_min, x_max, y_max = map(float, data[1:])
        # Clip to the image so every normalised value stays within [0, 1],
        # and skip boxes that become empty.
        x_min, y_min = max(0.0, x_min), max(0.0, y_min)
        x_max, y_max = min(float(img_w), x_max), min(float(img_h), y_max)
        if x_max <= x_min or y_max <= y_min:
            continue
        x_center = (x_min + x_max) / 2 / img_w
        y_center = (y_min + y_max) / 2 / img_h
        width = (x_max - x_min) / img_w
        height = (y_max - y_min) / img_h
        yolo_annotations.append(f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
    # An image with no kept boxes still gets an (empty) label file.
    with open(os.path.join(save_path, file_name.split('.')[0] + '.txt'), 'w') as file:
        file.writelines(annotation + '\n' for annotation in yolo_annotations)
通过以上代码就将标注文件转换成了yolo格式,只保留了1、2、3这三个类别,并将这三个类别归并成了一个类别person。
3、yolo标签展示
通过以下代码查看yolo格式的标注是否正确:
# 通过标签可视化图片以及标签
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
# Class id -> display name for the merged single-class labels.
class_names = {
    0: '0'
}
# Class id -> colour passed to OpenCV drawing calls (BGR order).
class_colors = {
    0: (0, 0, 255),  # red
}
# NOTE(review): these are the same tmp\labels and tmp\box_vis folders used by
# the raw-annotation preview. tmp\labels must hold YOLO-format files (e.g. a
# sample copied from yolo_label) for this parser to work, and earlier
# box_vis images are overwritten.
# Raw strings: "\c", "\d", "\W", "\i", "\l", ... are invalid escapes otherwise.
imgs_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\tmp\images'
labels_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\tmp\labels'
save_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\tmp\box_vis'
os.makedirs(save_path, exist_ok=True)
files_list = os.listdir(labels_path)
for file_name in tqdm(files_list):
    label_path = os.path.join(labels_path, file_name)
    img_name = file_name.split('.')[0] + '.jpg'
    # Each line is "class x_center y_center width height", normalised to [0, 1].
    with open(label_path, 'r') as f:
        labels = f.read().strip().split('\n')
    if labels[0] == '':
        # Empty label file: nothing to draw.
        continue
    # Close the file handle right away and force 3 channels so that
    # grayscale JPEGs do not break the RGB -> BGR conversion.
    with Image.open(os.path.join(imgs_path, img_name)) as image_pil:
        image_np = np.array(image_pil.convert('RGB'))
    image = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    img_h, img_w = image.shape[:2]
    thickness = 1  # line thickness for boxes and text
    font = cv2.FONT_HERSHEY_SIMPLEX
    for label in labels:
        parts = label.split()
        class_id = int(parts[0])
        x_center, y_center, width, height = map(float, parts[1:])
        # Normalised centre/size box -> pixel corner coordinates.
        x1 = int((x_center - width / 2) * img_w)
        y1 = int((y_center - height / 2) * img_h)
        x2 = int((x_center + width / 2) * img_w)
        y2 = int((y_center + height / 2) * img_h)
        # Fall back to a default colour / the numeric id for unknown classes.
        color = class_colors.get(class_id, (0, 255, 0))
        text = class_names.get(class_id, str(class_id))
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
        # Class name drawn just above the box.
        cv2.putText(image, text, (x1, y1 - 3), font, 0.7, color, thickness, cv2.LINE_AA)
    cv2.imwrite(os.path.join(save_path, img_name), image)
随机选择两张进行查看,发现转换是正确的:
4、查看标签个数
通过以下代码查看标签的个数:
import os
from tqdm import tqdm
def count_labels_in_txt_folder(folder_path):
    """Count YOLO boxes per class id over every ``.txt`` file in a folder.

    Args:
        folder_path: Directory holding YOLO label files.

    Returns:
        dict mapping the integer class id (first token of each non-blank
        line) to its number of occurrences, in order of first appearance.
    """
    counts = {}
    label_files = [name for name in os.listdir(folder_path) if name.endswith('.txt')]
    for name in tqdm(label_files):
        with open(os.path.join(folder_path, name), 'r') as handle:
            for row in handle:
                tokens = row.split()
                if not tokens:
                    # Blank line: no box on it.
                    continue
                cls = int(tokens[0])
                counts[cls] = counts.get(cls, 0) + 1
    return counts
# Folder with the generated WiderPerson YOLO labels. Raw string, because the
# path contains backslash sequences such as "\W" and "\y" that are invalid
# escapes in a normal string literal.
folder_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\yolo_label'
label_counts = count_labels_in_txt_folder(folder_path)
# Per-class box counts.
for class_id, count in label_counts.items():
    print(f'Class {class_id}: {count} instances')
# Overall number of boxes across all classes.
print(f'Total: {sum(label_counts.values())} instances')
代码运行完成后,可以发现一共有260186个标签。
5、图片和标签匹配
可以发现标注文件只有9000个,而图片数据有13382张,这是因为该数据集只标注了9000张图像,其余的没有标注,因此我们需要选择有标注的9000张图片来训练模型。实现代码如下:
import os
import shutil
from tqdm import tqdm
import random
def move_img(orig_imgs_path, labels_path, dest_imgs_path):
    """Copy the ``.jpg`` image matching each label file into *dest_imgs_path*.

    The image name is the label file name up to its first dot plus ``.jpg``;
    the copy keeps that name.
    """
    for label_name in tqdm(os.listdir(labels_path)):
        jpg_name = f"{label_name.split('.')[0]}.jpg"
        shutil.copy(os.path.join(orig_imgs_path, jpg_name),
                    os.path.join(dest_imgs_path, jpg_name))
def main():
    """Copy the WiderPerson images that have a YOLO label into train_image."""
    # Raw strings: "\c", "\d", "\W", "\I", "\y", ... are invalid escapes otherwise.
    orig_imgs_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\Images'
    labels_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\yolo_label'
    dest_imgs_path = r'D:\chb_work\workspace2\cv\data\person\WiderPerson\train_image'
    os.makedirs(dest_imgs_path, exist_ok=True)
    # Only images that have a label file are copied.
    move_img(orig_imgs_path, labels_path, dest_imgs_path)


if __name__ == '__main__':
    main()
通过以上操作,最终得到了两个文件夹train_image和yolo_label,一共9000条数据。
三、将NightOwls数据集和WiderPerson数据集合并后划分数据集
接下来需要合并NightOwls数据集和WiderPerson数据集,并将其划分为训练集、验证集和测试集,实现的代码如下:
import os
import random
import shutil
from tqdm import tqdm
def split_data(images_path, labels_path, dest_data_dir, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Randomly split one image/label folder pair into train/valid/test.

    Every file in *images_path* is copied, together with the same-stem
    ``.txt`` label from *labels_path*, into
    ``<dest_data_dir>/{train,valid,test}/{images,labels}``. After shuffling,
    the first ``int(n * train_ratio)`` images go to train, the next
    ``int(n * val_ratio)`` go to valid, and the rest go to test. Files
    already in *dest_data_dir* are left in place, so calling this once per
    dataset merges several datasets while keeping the ratios within each one.

    Args:
        images_path: Folder with the images.
        labels_path: Folder with one YOLO ``.txt`` label per image.
        dest_data_dir: Root folder of the merged dataset.
        train_ratio: Fraction of images for the training set.
        val_ratio: Fraction of images for the validation set; the test set
            receives the remainder.
        seed: Seed for a reproducible shuffle; ``None`` uses the global
            ``random`` state.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    # Small tolerance so float sums like 0.9 + 0.1 are not rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(f"invalid split ratios: train={train_ratio}, val={val_ratio}")

    for split in ("train", "valid", "test"):
        for sub_dir in ("images", "labels"):
            os.makedirs(os.path.join(dest_data_dir, split, sub_dir), exist_ok=True)

    # Sort first so that a fixed seed gives the same split regardless of the
    # order os.listdir returns.
    image_files = sorted(os.listdir(images_path))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(image_files)

    num_images = len(image_files)
    num_train = int(num_images * train_ratio)
    num_val = int(num_images * val_ratio)

    for i, image_file in enumerate(tqdm(image_files)):
        if i < num_train:
            split = "train"
        elif i < num_train + num_val:
            split = "valid"
        else:
            split = "test"
        # Derive the label name from the stem; str.replace(".jpg", ".txt")
        # left names of other image types unchanged.
        label_file = os.path.splitext(image_file)[0] + ".txt"
        shutil.copy(os.path.join(images_path, image_file),
                    os.path.join(dest_data_dir, split, "images", image_file))
        shutil.copy(os.path.join(labels_path, label_file),
                    os.path.join(dest_data_dir, split, "labels", label_file))
def main():
    """Split NightOwls and WiderPerson separately into one merged dataset folder."""
    # Raw strings: "\c", "\d", "\V", "\W", "\y", ... are invalid escapes otherwise.
    # NightOwls images / labels.
    nightowls_images_path = r"D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\train_image"
    nightowls_labels_path = r"D:\chb_work\workspace2\cv\data\person\NightOwls\Validation\yolo_label"
    # WiderPerson images / labels.
    widerperson_images_path = r"D:\chb_work\workspace2\cv\data\person\WiderPerson\train_image"
    widerperson_labels_path = r"D:\chb_work\workspace2\cv\data\person\WiderPerson\yolo_label"
    # Root of the merged dataset (train/valid/test, each with images/ and labels/).
    dest_data_dir = r"D:\chb_work\workspace2\cv\data\person\20240403person_data"
    os.makedirs(dest_data_dir, exist_ok=True)
    # Splitting each source on its own keeps the ratios within both datasets.
    split_data(nightowls_images_path, nightowls_labels_path, dest_data_dir)
    split_data(widerperson_images_path, widerperson_labels_path, dest_data_dir)


if __name__ == '__main__':
    main()
划分完成后,一共得到训练集19578张,验证集2447张,测试集2448张。划分完成后的文件目录如下:
接下来需要在同级目录下新建yaml文件,文件内容如下:
# Relative paths as in the original setup; check how your YOLO version
# resolves them (relative to the working directory or to a `path:` key).
train: ../20240403person_data/train/images
val: ../20240403person_data/valid/images
# Test split produced by split_data, so the model can be evaluated on it.
test: ../20240403person_data/test/images
nc: 1  # number of classes
names: ['person']
最终的目录结构如下:
之后就可以利用该数据集训练yolo模型了。