1.先到github上下载,ultralytics源代码
2.pycharm新建一个项目
3.准备训练数据
数据的结构如下
不需要.yaml文件,代码会自动识别要分的类
4.创建一个训练文件
import torch
import random
import cv2
import numpy as np
import os
from ultralytics import YOLO
def TrainData():
model = YOLO('D:\\Source\\SourceMe\\PythonProject\\TrainClassificationPill\\TrainClassificationPill\\yolov8x-cls.pt')
'''这里把amp设置成False,不然训练的时候回去网上默认下载预处理权重'''
results = model.train(data='D:\\Source\\SourceMe\\PythonProject\\TrainClassificationPill\\TrainClassificationPill\\Dataset',
epochs=200,
batch=4,
amp=False)
sucess = model.export(format='onnx')
print(results)
def TestModelUltralytics():
model = YOLO("加载训练的.pt文件")
img=cv2.imread("要检测的图片")
yolo_classes=list(model.names.values())
classes_ids=[yolo_classes.index(clas) for clas in yolo_classes]
conf= 0.2
results=model.predict(img,conf=conf)
pass
if __name__ == '__main__':
TrainData()
pass