【算法竞赛】kaggle的Global Wheat Detection练习(基于yolov8训练及预测)


前言

kaggle上的Global Wheat Detection是一个经典的目标检测竞赛,数据集公开,目标明确,难度不大,很适合初学者作为前期项目学习的资料。


一、竞赛介绍

竞赛目标是对小麦麦穗进行检测,将麦穗头检测框选出来,是典型的目标检测问题。由于密密麻麻的小麦植物常常重叠在一起,且风的吹拂会使照片变得模糊不清,给目标检测带来一定的难题。
竞赛链接:https://www.kaggle.com/competitions/global-wheat-detection
本文用yolov8作为基准模型,快速上手yolov8的使用。

二、前期准备

1.yolov8安装

yolov8和yolov5一样都是由ultralytics开发,yolov8可以用来做目标检测,图像分割,姿态估计,使用方便,非常适合初学者体验。
安装方式:

pip install ultralytics

2.数据集准备

从kaggle下载数据集。
方法一:点击下面,画圈部分
在这里插入图片描述
方法二:
使用kaggle的api下载

pip install kaggle
kaggle competitions download -c global-wheat-detection

下载完后,打开train.csv文件夹。
![在这里插入图片描述](https://img-blog.csdnimg.cn/0109097d6ecc4732ad3540daa86c356e.png其中每一行表示一个目标,一个image_id表示一张图片,之所以id相同,表示一张图像上的多个目标,图像高宽为1024*1024,bbox就是标注好的目标坐标[xmin, ymin, width, height],(xmin,ymin)表示目标左上角坐标,width,height表示目标宽高。
将该数据转换成yolo格式:

import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import shutil

def data_to_yolo(path):
    df = pd.read_csv(path)
    bboxs = np.stack(df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))
    for i, column in enumerate(['x', 'y', 'w', 'h']):
        df[column] = bboxs[:, i]
    df.drop(columns=['bbox'], inplace=True)
    df['x_center'] = df['x'] + df['w'] / 2
    df['y_center'] = df['y'] + df['h'] / 2
    df['classes'] = 0
    return df[['image_id', 'x', 'y', 'w', 'h', 'x_center', 'y_center', 'classes']]

if __name__ == '__main__':

    train_csv_path = r'F:\kaggle\Global Wheat Detection\global-wheat-detection\train.csv' # 训练表格文件所在路径
    new_dataframe = data_to_yolo(train_csv_path)
    print(new_dataframe.head())
    
    index = list(set(new_dataframe.image_id))
    
    source = 'train'
    if True:
        for fold in [0]:
            val_index = index[len(index) * fold // 5:len(index) * (fold + 1) // 5]
            for name, mini in tqdm(new_dataframe.groupby('image_id')):
                if name in val_index:
                    path2save = 'val2017/'
                else:
                    path2save = 'train2017/'
                if not os.path.exists('convertor/fold{}/labels/'.format(fold) + path2save):
                    os.makedirs('convertor/fold{}/labels/'.format(fold) + path2save)
                with open('convertor/fold{}/labels/'.format(fold) + path2save + name + ".txt", 'w+') as f:
                    row = mini[['classes', 'x_center', 'y_center', 'w', 'h']].astype(float).values
                    row = row / 1024
                    row = row.astype(str)
                    for j in range(len(row)):
                        text = ' '.join(row[j])
                        f.write(text)
                        f.write("\n")
                if not os.path.exists('convertor/fold{}/images/{}'.format(fold, path2save)):
                    os.makedirs('convertor/fold{}/images/{}'.format(fold, path2save))
                shutil.copy("global-wheat-detection/{}/{}.jpg".format(source, name),
                            'convertor/fold{}/images/{}/{}.jpg'.format(fold, path2save, name))
    
    print('Finish')

三、模型训练

要使用yolov8训练,需要配置好两个yaml文件,其中数据格式在\ultralytics-main\ultralytics\datasets中,可以直接用coco128.yaml修改成自己的图像路径即可。
在这里插入图片描述
注意要将names改成一个类别,download的路径删除。如图所示即可。path是图像所在的主路径,train和val是其下的文件夹
配置完后就可以直接训练。

import utlralytics
    # Load a model
    model = ultralytics.YOLO(r".\config\yolov8n.yaml")  # build a new model from scratch
    model = ultralytics.YOLO("yolov8n.pt")  # load a pretrained model (recommended for training)

    # Use the model
    model.train(data=r".\coco128.yaml", epochs=3, batch=8)  # train the model
    metrics = model.val()  # evaluate model performance on the validation set
    results = model("0a3cb453f.jpg")  # predict on an image

训练完成后,会在当前文件夹下生成run文件夹。里面有模型训练权重和一些训练结果图,方便后续部署。
预测时:

import ultralytics
import numpy as np
import cv2

if __name__ == '__main__':
    img_path = r".\test\53f253011.jpg"
    model = ultralytics.YOLO(r".\best.pt")
    
    results = model(img_path)
    boxes = np.array(results[0].boxes.xyxy)
    img = cv2.imread(img_path)
    for box in boxes:
        cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), thickness=5)
    cv2.putText(img, f'number={len(boxes)}', (img.shape[0]//2, img.shape[1]//2), cv2.FONT_HERSHEY_SIMPLEX, 2, color=(0, 0, 255),thickness=3)
    cv2.imwrite('test.jpg', img)
    print("Finish")

总结

以上介绍yolov8的使用,即使第一次接触目标检测的也能快速获得结果,不得不说ultralytics封装的真好,直接调用其API即可。注意,不能直接用来跑竞赛。

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值