前言
kaggle上的Global Wheat Detection是一个经典的目标检测竞赛,数据集公开,目标明确,难度不大,很适合初学者作为前期项目学习的资料。
一、竞赛介绍
竞赛目标是对小麦麦穗进行检测,将麦穗头检测框选出来,是典型的目标检测问题。由于密密麻麻的小麦植物常常重叠在一起,且风的吹拂会使照片变得模糊不清,给目标检测带来一定的难题。
竞赛链接:https://www.kaggle.com/competitions/global-wheat-detection
本文用yolov8作为基准模型,快速上手yolov8的使用。
二、前期准备
1.yolov8安装
yolov8和yolov5一样都是由ultralytics开发,yolov8可以用来做目标检测,图像分割,姿态估计,使用方便,非常适合初学者体验。
安装方式:
pip install ultralytics
2.数据集准备
从kaggle下载数据集。
方法一:点击下面,画圈部分
方法二:
使用kaggle的api下载
pip install kaggle
kaggle competitions download -c global-wheat-detection
下载完后,打开train.csv文件夹。
其中每一行表示一个目标,一个image_id表示一张图片,之所以id相同,表示一张图像上的多个目标,图像高宽为1024*1024,bbox就是标注好的目标坐标[xmin, ymin, width, height],(xmin,ymin)表示目标左上角坐标,width,height表示目标宽高。
将该数据转换成yolo格式:
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import shutil
def data_to_yolo(path):
df = pd.read_csv(path)
bboxs = np.stack(df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))
for i, column in enumerate(['x', 'y', 'w', 'h']):
df[column] = bboxs[:, i]
df.drop(columns=['bbox'], inplace=True)
df['x_center'] = df['x'] + df['w'] / 2
df['y_center'] = df['y'] + df['h'] / 2
df['classes'] = 0
return df[['image_id', 'x', 'y', 'w', 'h', 'x_center', 'y_center', 'classes']]
if __name__ == '__main__':
train_csv_path = r'F:\kaggle\Global Wheat Detection\global-wheat-detection\train.csv' # 训练表格文件所在路径
new_dataframe = data_to_yolo(train_csv_path)
print(new_dataframe.head())
index = list(set(new_dataframe.image_id))
source = 'train'
if True:
for fold in [0]:
val_index = index[len(index) * fold // 5:len(index) * (fold + 1) // 5]
for name, mini in tqdm(new_dataframe.groupby('image_id')):
if name in val_index:
path2save = 'val2017/'
else:
path2save = 'train2017/'
if not os.path.exists('convertor/fold{}/labels/'.format(fold) + path2save):
os.makedirs('convertor/fold{}/labels/'.format(fold) + path2save)
with open('convertor/fold{}/labels/'.format(fold) + path2save + name + ".txt", 'w+') as f:
row = mini[['classes', 'x_center', 'y_center', 'w', 'h']].astype(float).values
row = row / 1024
row = row.astype(str)
for j in range(len(row)):
text = ' '.join(row[j])
f.write(text)
f.write("\n")
if not os.path.exists('convertor/fold{}/images/{}'.format(fold, path2save)):
os.makedirs('convertor/fold{}/images/{}'.format(fold, path2save))
shutil.copy("global-wheat-detection/{}/{}.jpg".format(source, name),
'convertor/fold{}/images/{}/{}.jpg'.format(fold, path2save, name))
print('Finish')
三、模型训练
要使用yolov8训练,需要配置好两个yaml文件,其中数据格式在\ultralytics-main\ultralytics\datasets中,可以直接用coco128.yaml修改成自己的图像路径即可。
注意要将names改成一个类别,download的路径删除。如图所示即可。path是图像所在的主路径,train和val是其下的文件夹
配置完后就可以直接训练。
import utlralytics
# Load a model
model = ultralytics.YOLO(r".\config\yolov8n.yaml") # build a new model from scratch
model = ultralytics.YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
# Use the model
model.train(data=r".\coco128.yaml", epochs=3, batch=8) # train the model
metrics = model.val() # evaluate model performance on the validation set
results = model("0a3cb453f.jpg") # predict on an image
训练完成后,会在当前文件夹下生成run文件夹。里面有模型训练权重和一些训练结果图,方便后续部署。
预测时:
import ultralytics
import numpy as np
import cv2
if __name__ == '__main__':
img_path = r".\test\53f253011.jpg"
model = ultralytics.YOLO(r".\best.pt")
results = model(img_path)
boxes = np.array(results[0].boxes.xyxy)
img = cv2.imread(img_path)
for box in boxes:
cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), thickness=5)
cv2.putText(img, f'number={len(boxes)}', (img.shape[0]//2, img.shape[1]//2), cv2.FONT_HERSHEY_SIMPLEX, 2, color=(0, 0, 255),thickness=3)
cv2.imwrite('test.jpg', img)
print("Finish")
总结
以上介绍yolov8的使用,即使第一次接触目标检测的也能快速获得结果,不得不说ultralytics封装的真好,直接调用其API即可。注意,不能直接用来跑竞赛。