#Datawhale #AI夏令营 #针对城市管理中违规行为的智能识别系统——YOLO解决方案 (2)

系列文章目录

Task1: 第一篇请点击这里
Task2 (loading…)
Task3 (loading…)


Task 2


前言

这篇文章延续上次对跑通Baseline的分享与讲解,将对建模方案进行深入解读,并初步探讨进阶方案。

本篇文章主要是记录和分享自己在夏令营中的学习过程和遇到的困难


一、对当前模型的分析与优化建议

角度 0: 优点

详细代码与讲解在上一篇文章:Task 1
Baseline的模型在以下几个方面比较优秀:

  1. 实现了完整的数据处理流程,包括数据下载、预处理、目标检测模型训练和测试结果生成
  2. 使用YAML文件定义了数据集路径和类别信息,使得数据集管理更加方便
  3. 代码结构清晰,分步实现了数据集的划分、标注文件的处理和模型训练等功能

但它也存在一些问题,导致模型效果不够理想,以下是我自己的分析与建议,仅供参考,大家也可以结合Datawhale的官方文件来考虑进阶方案:https://linklearner.com/activity/16/16/68

角度 1: 数据读取

1.1 存在硬编码路径的问题,建议将路径信息提取为变量,增加代码的灵活性

# 定义路径变量
train_annotation_path = '训练集(有标注第一批)/标注/'
train_video_path = '训练集(有标注第一批)/视频/'

角度 2: 数据处理

2.1 处理训练集和验证集的代码逻辑几乎完全相同,只是输出路径不同,因此可以将这部分逻辑封装成一个函数,提高代码的复用性和可维护性

# 代码示例
def process_dataset(annos, videos, output_dir):
    for anno_path, video_path in zip(annos, videos):
        print(video_path)
        anno_df = pd.read_json(anno_path)
        cap = cv2.VideoCapture(video_path)
        frame_idx = 0 
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            img_height, img_width = frame.shape[:2]
            frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
            cv2.imwrite(output_dir + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)

            if len(frame_anno) != 0:
                with open(output_dir + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:
                    for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                        category_idx = category_labels.index(category)
                        x_min, y_min, x_max, y_max = bbox
                        x_center = (x_min + x_max) / 2 / img_width
                        y_center = (y_min + y_max) / 2 / img_height
                        width = (x_max - x_min) / img_width
                        height = (y_max - y_min) / img_height
                        up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')

            frame_idx += 1

# 处理训练集
process_dataset(train_annos[:5], train_videos[:5], 'yolo-dataset/train/')
# 处理验证集
process_dataset(train_annos[-3:], train_videos[-3:], 'yolo-dataset/val/')

角度 3: 模型选择

3.1 可以尝试更多目标检测的相关模型,或者切换不同的模型预训练权重,例如:
 
图片1

来源于https://linklearner.com/activity/16/16/68
#原始
!wget http://mirror.coggle.club/yolo/yolov8n-v8.2.0.pt -O yolov8n.pt
#修改后(任选其一)
!wget http://mirror.coggle.club/yolo/yolov8s-v8.2.0.pt -O yolov8s.pt
!wget http://mirror.coggle.club/yolo/yolov8m-v8.2.0.pt -O yolov8m.pt
!wget http://mirror.coggle.club/yolo/yolov8l-v8.2.0.pt -O yolov8l.pt
!wget http://mirror.coggle.club/yolo/yolov8x-v8.2.0.pt -O yolov8x.pt

3.2 模型集成可以通过结合多个模型的预测结果来提高模型性能,如投票、平均等方式

# 代码示例
model1 = Model1()
model2 = Model2()
ensemble_predictions = (model1_predictions + model2_predictions) / 2

角度 4: 模型优化

4.1 可以尝试调整训练参数,如学习率、批量大小等,以优化模型的训练过程

# 代码示例:
results = model.train(data="yolo-dataset/yolo.yaml", epochs=5, imgsz=1080, batch=32, lr=0.001)

4.2 可以使用学习率调度器,动态调整学习率,有助于加快模型收敛速度并提高模型性能。常见的学习率调度器包括StepLR、ReduceLROnPlateau、CosineAnnealing等

# 代码示例
from torch.optim.lr_scheduler import StepLR
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(num_epochs):
    scheduler.step()
    train(...)

4.3 可以使用正则化技术,即在损失函数中添加正则化项来限制模型的复杂度,有助于防止过拟合。常见的正则化方法包括L1正则化和L2正则化

# 代码示例
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

角度 5: 模型训练

5.1 在原始数据有限的情况下,可以对训练数据进行数据增强(随机变换),有助于提高模型的泛化能力。常见的数据增强方法包括随机裁剪、旋转、翻转、色彩变换等(原代码没有相关库,需要额外导入torchvision这个库)

# 代码示例
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

5.2 训练模型时可以使用交叉验证技术来评估模型性能,并选择训练最优模型,常见的交叉验证方法包括K折交叉验证和留一交叉验证

# 代码示例
from sklearn.model_selection import KFold

kf = KFold(n_splits=5)
for train_index, val_index in kf.split(data):
    train_data, val_data = data[train_index], data[val_index]
    model = train_model(train_data)
    evaluate_model(model, val_data)

5.3 可以增加更多的视频到训练集中进行训练,以增加训练数据的多样性,提高模型的泛化能力

# 代码示例
for anno_path, video_path in zip(train_annos[:10], train_videos[:10]):   #原始代码是 for anno_path, video_path in zip(train_annos[:5], train_videos[:5]):
    print(video_path)
    anno_df = pd.read_json(anno_path)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0 
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
        cv2.imwrite('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)
        if len(frame_anno) != 0:
            with open('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:
                for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                    category_idx = category_labels.index(category)                    
                    x_min, y_min, x_max, y_max = bbox
                    x_center = (x_min + x_max) / 2 / img_width
                    y_center = (y_min + y_max) / 2 / img_height
                    width = (x_max - x_min) / img_width
                    height = (y_max - y_min) / img_height
                    if x_center > 1:
                        print(bbox)
                    up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')       
        frame_idx += 1

角度 6: 异常处理

6.1 没有对异常情况进行充分处理(如文件读取失败、视频帧读取失败等情况)进行充分处理,建议添加异常处理机制,确保代码的稳定性和健壮性

# 代码示例
try:
    # 代码逻辑
except Exception as e:
    print(f"Error: {e}")

二、进阶实践效果

1. 本次进阶

  • 参照3.1 更换了yolov8s和yolov8m预训练权重分别尝试
  • 参照4.1 更换 epochs=2(原始)epochs=5
  • 参照5.3 增加训练集数据

由于时间有限,只尝试了几个改进方案,后续会进一步优化模型并记录分享

2. 初步效果

 
F1 curve & R curve
P curve & PR curve
 
result.png


总结

只是本篇针对Task2的总结,系列总结会在第三篇

以上就是对task2任务的详细讲解,进行了进阶学习、思路拓展和初步模型优化,后续会继续深化学习,尝试更多模型优化方案,持续分享。

  • 27
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值