这段代码的主要目的是处理一个数据集中的图像和标签文件,将标签文件中定义的边界框(Bbox)绘制到对应的图像上,并保存这些带有边界框的图像。通过这种方式,可以将数据集中的标签信息可视化,便于进一步的分析和处理。常用于CV中的目标检测数据集的标签质量检查。
这是我们常见是标签格式是yolo的数据集的格式:
目录:
dataset:绝对路径:"C:\Desktop\dataset"
--images:绝对路径:"C:\Desktop\dataset\images"
----.jpg
--labels:绝对路径:"C:\Desktop\dataset\labels"
.....txt
.txt文件的标签格式yolo格式(例子):
3 0.385417 0.498148 0.0625 0.033333
数据集是这样子的,同时给出了标签的格式。labels目录中的.txt中的数据是目标检测的标签数据,每一行数据分别对应一个GT框,每一行数据从左到右分别是:id, x,y,w,h。列1 - 目标类别id , 列2 - 目标中心位置x, 列3 - 目标中心位置y, 列4 - 目标宽度w,列5 - 目标高度h。x,y,w,h是小于1的浮点数,因为是经过对图像进行了归一化处理得到的值,也就是目标的真实的x,w值除以图像的宽度,y,h除以图像的高度。
以下是代码的详细思路和步骤:
-
定义路径:
root_dir
:数据集的根目录路径(需要用户自行指定)。images_dir
:存放原始图像的目录路径。labels_dir
:存放标签文件的目录路径。output_images_plot_dir
:存放绘制了边界框的图像的输出目录路径。
-
创建输出目录:
- 使用
os.makedirs
确保输出目录存在,exist_ok=True
参数表示如果目录已存在则不抛出异常。
- 使用
-
读取并处理标签文件:
- 遍历
labels_dir
目录中的所有文件,筛选出以.txt
结尾的标签文件。 - 对于每个标签文件:
- 读取文件内容,每行数据被分割并去除空白字符,存储为标签列表。
- 将标签文件名中的
.txt
替换为.jpg
,获取对应的图像文件名。
- 遍历
-
读取并检查图像文件:
- 拼接图像文件的完整路径,并检查文件是否存在。
- 如果图像文件不存在,则打印一条消息并跳过当前标签文件。
-
读取图像:
- 使用
cv2.imread
读取图像文件,如果图像未被正确加载(即image is None
),则打印一条消息并跳过当前图像。
- 使用
-
绘制边界框和标签:
- 遍历标签列表中的每个标签,假设每个标签包含5个元素(类别ID,x中心,y中心,宽度,高度)。
- 将归一化的边界框坐标转换为实际的像素坐标。
- 使用
cv2.rectangle
在图像上绘制矩形框,颜色为绿色,线宽为2像素。 - 使用
cv2.putText
在矩形框的左上角绘制类别ID,字体颜色为红色。
-
保存绘制了边界框的图像:
- 拼接输出图像的完整路径,并使用
cv2.imwrite
将图像保存到输出目录。
- 拼接输出图像的完整路径,并使用
-
完成处理:
- 打印一条消息,表示所有图像的处理已完成
import os
import cv2
# 根目录地址
root_dir = #此处给出需要画出Bbox的数据集路径
# 图片目录
images_dir = os.path.join(root_dir, 'images_cf')
# 标签目录
labels_dir = os.path.join(root_dir, 'labels')
# 输出目录,用于存放绘制了标签的图片
output_images_plot_dir = os.path.join(root_dir, 'images_plot')
os.makedirs(output_images_plot_dir, exist_ok=True)
# 读取标签文件并绘制框
for label_file in os.listdir(labels_dir):
if label_file.endswith('.txt'):
# 读取标签文件
with open(os.path.join(labels_dir, label_file), 'r') as f:
labels = [line.strip().split() for line in f.readlines() if line.strip()]
# 读取对应的图片
image_file = label_file.replace('.txt', '.jpg') # 确保是.jpg文件
image_path = os.path.join(images_dir, image_file)
# 确保图像文件存在
if not os.path.isfile(image_path):
print(f"图像文件 {image_path} 不存在,跳过。")
continue
# 读取图像
image = cv2.imread(image_path)
# 确保图像被正确加载
if image is None:
print(f"无法加载图像 {image_path},跳过。")
continue
# 绘制框和标记ID
for label in labels:
if len(label) == 5:
# 解析标签数据
_, x_center, y_center, width, height = map(float, label)
# 将归一化坐标转换为实际坐标
x_min = int((x_center - width / 2) * image.shape[1])
y_min = int((y_center - height / 2) * image.shape[0])
x_max = int((x_center + width / 2) * image.shape[1])
y_max = int((y_center + height / 2) * image.shape[0])
# 绘制矩形框,颜色为绿色
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
# 提取类别ID,假设第一个元素是ID
id = int(label[0])
# 在框的左上角标记ID,字体颜色为红色
label_text = str(id) # 将ID转换为字符串
text_size, _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
text_x = x_min
text_y = y_min - text_size[1] - 5 # 留出一些空间避免重叠
cv2.putText(image, label_text,
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 255), # 红色文本
1)
# 保存图片到输出目录
output_image_path = os.path.join(output_images_plot_dir, os.path.basename(image_file))
cv2.imwrite(output_image_path, image)
print("所有图片已处理完成。")