目录
本文是对task02的进阶与回顾。
一、数据增强
在task02中也提到过数据增强,那么它的具体含义是:
数据增强是机器学习和深度学习中常用的技术,用于通过从现有数据集中生成新的训练样本来提高模型的泛化能力。干净一致的数据对于创建性能良好的模型至关重要。常见的增强技术包括翻转、旋转、缩放和颜色调整。多个库,例如 Albumentations、Imgaug 和 TensorFlow的 ImageDataGenerator,可以生成这些增强。
数据增强的具体方法:
数据增强方法 | 描述 |
Mosaic Augmentation | 将四张训练图像组合成一张,增加物体尺度和位置的多样性。 |
Copy-Paste Augmentation | 复制一个图像的随机区域并粘贴到另一个图像上,生成新的训练样本。 |
Random Affine Transformations | 包括图像的随机旋转、缩放、平移和剪切,增加对几何变换的鲁棒性。 |
MixUp Augmentation | 通过线性组合两张图像及其标签创造合成图像,增加特征空间的泛化。 |
Albumentations | 一个支持多种增强技术的图像增强库,提供灵活的增强管道定义。 |
HSV Augmentation | 对图像的色相、饱和度和亮度进行随机变化,改变颜色属性。 |
Random Horizontal Flip | 沿水平轴随机翻转图像,增加对镜像变化的不变性。 |
简单来说就是把图像经过旋转,镜像,反转,合成等方法,来增加数据量,提高模型性能的一种手段叫做数据增强。(这是我的理解,不对还请指出)。
二、设置参数
https://docs.ultralytics.com/usage/cfg/#train-settings
YOLO 模型的训练设置包括多种超参数和配置,这些设置会影响模型的性能、速度和准确性。微调涉及采用预先训练的模型并调整其参数以提高特定任务或数据集的性能。该过程也称为模型再训练,使模型能够更好地理解和预测在实际应用中将遇到的特定数据的结果。您可以根据模型评估重新训练模型,以获得最佳结果。
通常,在初始训练时期,学习率从低开始,逐渐增加以稳定训练过程。但是,由于您的模型已经从以前的数据集中学习了一些特征,因此立即从更高的学习率开始可能更有益。在 YOLO 中绝大部分参数都可以使用默认值。(基本敲代码的时候你所用到的软件都会提醒你该怎么输入不用死记硬背,或者直接看英文意思即可)
-
imgsz: 训练时的目标图像尺寸,所有图像在此尺寸下缩放。
-
save_period: 保存模型检查点的频率(周期数),-1 表示禁用。
-
device: 用于训练的计算设备,可以是单个或多个 GPU,CPU 或苹果硅的 MPS。
-
optimizer: 训练中使用的优化器,如 SGD、Adam 等,或 'auto' 以根据模型配置自动选择。
-
momentum: SGD 的动量因子或 Adam 优化器的 beta1。
-
weight_decay: L2 正则化项。
-
warmup_epochs: 学习率预热的周期数。
-
warmup_momentum: 预热阶段的初始动量。
-
warmup_bias_lr: 预热阶段偏置参数的学习率。
-
box: 边界框损失在损失函数中的权重。
-
cls: 分类损失在总损失函数中的权重。
-
dfl: 分布焦点损失的权重。
在YOLOv5及其后续版本中,imgsz
可以被设置为一个整数,用于训练和验证模式,表示将输入图像调整为正方形的尺寸,例如imgsz=640
意味着图像将被调整为640x640像素。对于预测和导出模式,imgsz
可以被设置为一个列表,包含宽度和高度,例如imgsz=[640, 480]
,表示图像将被调整为640像素宽和480像素高。较大的图像尺寸可能会提高模型的准确性,但会增加计算量和内存消耗。较小的图像尺寸可能会降低模型的准确性,但会提高计算速度和内存效率。因此,用户应根据实际场景需求及硬件资源限制,设置合适的输入图像尺寸。
感觉这和大模型没什么区别,模型的优化本质来说就是调参,概率问题。
三、设置性能
YOLO模型的预测结果通常包括多个组成部分,每个部分提供关于检测到的对象的不同信息。同时 YOLO 能够处理包括单独图像、图像集合、视频文件或实时视频流在内的多种数据源,也能够一次性处理多个图像或视频帧,进一步提高推理速度。(适当记记,虽然我也记不住,不会的时候多问问ai也是一种方法)
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8n.pt") # pretrained YOLOv8n model
# Run batched inference on a list of images
results = model(["im1.jpg", "im2.jpg"]) # return a list of Results objects
# Process results list
for result in results:
boxes = result.boxes # Boxes object for bounding box outputs
masks = result.masks # Masks object for segmentation masks outputs
keypoints = result.keypoints # Keypoints object for pose outputs
probs = result.probs # Probs object for classification outputs
obb = result.obb # Oriented boxes object for OBB outputs
result.show() # display to screen
result.save(filename="result.jpg") # save to disk
YOLOv8模型的使用者提供了灵活性,允许根据特定应用场景的需求调整模型的行为和性能。例如,如果需要减少误报,可以提高conf
阈值;如果需要提高模型的执行速度,可以在支持的硬件上使用half
精度;如果需要处理视频数据并希望加快处理速度,可以调整vid_stride
来跳过某些帧。这些参数的适当配置对于优化模型的预测性能至关重要。
四、实际操作
继续之前的代码再次修改,在task02中我进行了数据增强这一操作,但是结果与之前使用YOLOv8m的代码运行出来的结果一模一样,所以对代码进行再次修改与尝试。
重新把代码整合的结果如下:
import os
import cv2
import glob
import pandas as pd
from tqdm import tqdm
from albumentations import (
Compose, HorizontalFlip, Rotate, RandomBrightnessContrast, Resize
)
# 定义类别标签
category_labels = {'非机动车违停': 0, '机动车违停': 1, '垃圾桶满溢': 2, '违法经营': 3}
# 数据增强定义
data_transform = Compose([
HorizontalFlip(p=0.5),
Rotate(limit=10, p=0.5),
RandomBrightnessContrast(p=0.2),
Resize(height=640, width=640, p=1.0)
], bbox_params={'format': 'pascal_voc', 'label_fields': ['category']})
# 确保目录结构存在
os.makedirs('./yolo-dataset/train/images', exist_ok=True)
os.makedirs('./yolo-dataset/train/labels', exist_ok=True)
# 读取标注文件和视频文件的路径
train_annos = glob.glob('训练集(有标注第一批)/标注/*.json')
train_videos = glob.glob('训练集(有标注第一批)/视频/*.mp4')
# 只处理最后三个视频文件及其对应的标注文件
train_annos = train_annos[-3:]
train_videos = train_videos[-3:]
# 确保标注文件和视频文件的数量一致
assert len(train_annos) == len(train_videos), "Number of annotation files does not match number of video files."
# 使用 tqdm 进度条显示进度
for anno_path, video_path in tqdm(zip(train_annos, train_videos), total=len(train_annos)):
print(f"Processing video: {video_path}")
# 读取标注文件
try:
anno_df = pd.read_json(anno_path)
except Exception as e:
print(f"Failed to read annotation file {anno_path}: {e}")
continue
# 打开视频文件
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Failed to open video file {video_path}")
continue
frame_idx = 0
video_name = os.path.splitext(os.path.basename(video_path))[0]
anno_name = os.path.splitext(os.path.basename(anno_path))[0]
while True:
ret, frame = cap.read()
if not ret:
break
img_height, img_width = frame.shape[:2]
# 获取当前帧的标注信息
frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
# 数据增强
bboxes = []
categories = []
if not frame_anno.empty:
for _, row in frame_anno.iterrows():
bboxes.append(row['bbox'])
categories.append(row['category'])
# 应用数据增强
augmented = data_transform(image=frame, bboxes=bboxes, category=categories)
frame = augmented['image']
bboxes = augmented['bboxes']
# 保存图像文件
image_path = f'./yolo-dataset/train/images/{anno_name}_{frame_idx}.jpg'
cv2.imwrite(image_path, frame)
# 如果当前帧有标注信息,则写入标签文件
if bboxes:
label_path = f'./yolo-dataset/train/labels/{anno_name}_{frame_idx}.txt'
with open(label_path, 'w') as up:
for bbox, category in zip(bboxes, categories):
try:
category_idx = category_labels[category]
except KeyError:
print(f"Category '{category}' not found in category labels.")
continue
x_min, y_min, x_max, y_max = bbox
x_center = (x_min + x_max) / 2 / img_width
y_center = (y_min + y_max) / 2 / img_height
width = (x_max - x_min) / img_width
height = (y_max - y_min) / img_height
if x_center > 1 or y_center > 1 or width > 1 or height > 1:
print(f"Bounding box {bbox} exceeds image dimensions.")
up.write(f'{category_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n')
frame_idx += 1
cap.release()
print("处理完成。")
但此时会出现一个问题:
遇到 OSError: [Errno 28] No space left on device 错误表示磁盘空间不足。这意味着你在尝试写入文件时,目标磁盘没有足够的可用空间来存储新的文件。解决这个问题的方法通常有两种:
1、清理磁盘空间:删除不必要的文件或移动文件到其他位置以释放空间。
2、使用其他存储位置:将文件保存到另一个有足够空间的磁盘或存储设备上。
可以修改为:
train_annos = train_annos[:10]
train_videos = train_videos[:10]
这是处理前十个视频,也可以改为前五个,后三个,原来的代码就是分别处理的前五个和后三个。
今天我刚回学校不知道是不是网不好的原因,我那个启动一次就启动不了了,今天就要交笔记了,等我明天再打开看看试试,今天太赶了,明天继续补充。
五、参考资料
-
https://docs.ultralytics.com/modes/predict/#inference-arguments
-
https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation/