YOLOv8: OBB 旋转模型如何训练自己的模型(记录)
一、更新升级ultralytics模块(已经是最新的,略过即可)
pip install --upgrade ultralytics -i https://pypi.tuna.tsinghua.edu.cn/simple/
二、数据采集及制作
1、使用rolabelimg工具进行标注(安装下载,百度搜索即可)
2、xml文件转换成txt文件
import os
import xml. etree. ElementTree as ET
import math
# Class names used in the annotations. totxt() writes a label line only for
# objects whose <name> equals one of these entries, together with the index of
# the matching entry in this list.
# NOTE(review): the value below looks like a placeholder ("class name") —
# replace it with the real class names before running.
cls_list= [ '分类名称' ]
def edit_xml(xml_file, dotaxml_file):
    """Rewrite one annotation XML so every box is stored as four corner points.

    After the call, each ``<object>`` has a ``<bndbox>`` element whose
    coordinate children are ``x0, y0, x1, y1, x2, y2, x3, y3`` (appended in
    that order); the original coordinate children are removed.

    * An axis-aligned ``<bndbox>`` (``xmin/ymin/xmax/ymax``) has its values
      clamped to be non-negative and is expanded to the corners
      ``(xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)``.
    * A rotated ``<robndbox>`` (``cx/cy/w/h/angle``) is renamed to
      ``<bndbox>``; each corner of the ``w`` x ``h`` rectangle centred on
      ``(cx, cy)`` is passed through ``rotatePoint`` with ``-angle``.

    :param xml_file: path of the XML file to read
    :param dotaxml_file: path the rewritten XML is written to (UTF-8)
    :return: None
    """
    tree = ET.parse(xml_file)

    for obj in tree.findall('object'):
        obj_bnd = obj.find('robndbox')

        if obj_bnd is None:
            # Plain rectangle: read the four edges, then drop the old children.
            obj_bnd = obj.find('bndbox')
            edges = [obj_bnd.find(tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            xmin, ymin, xmax, ymax = [max(float(edge.text), 0) for edge in edges]
            for edge in edges:
                obj_bnd.remove(edge)
            corners = [
                (str(xmin), str(ymax)),
                (str(xmax), str(ymax)),
                (str(xmax), str(ymin)),
                (str(xmin), str(ymin)),
            ]
        else:
            # Rotated rectangle: downstream code looks the box up as 'bndbox'.
            obj_bnd.tag = 'bndbox'
            params = [obj_bnd.find(tag) for tag in ('cx', 'cy', 'w', 'h', 'angle')]
            cx, cy, w, h, angle = [float(param.text) for param in params]
            for param in params:
                obj_bnd.remove(param)
            half_w, half_h = w / 2, h / 2
            corners = [
                rotatePoint(cx, cy, cx + dx, cy + dy, -angle)
                for dx, dy in (
                    (-half_w, -half_h),
                    (half_w, -half_h),
                    (half_w, half_h),
                    (-half_w, half_h),
                )
            ]

        # Append x0, y0, x1, y1, ... in corner order.
        for index, (x_text, y_text) in enumerate(corners):
            for axis, text in (('x', x_text), ('y', y_text)):
                point = ET.Element('{}{}'.format(axis, index))
                point.text = text
                obj_bnd.append(point)

    # Written once, after every object has been converted, so a file that
    # contains no objects still produces an output file.
    tree.write(dotaxml_file, method='xml', encoding='utf-8')
def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate point ``(xp, yp)`` about centre ``(xc, yc)`` by ``theta`` radians.

    With ``dx = xp - xc`` and ``dy = yp - yc`` the result is::

        x = xc + cos(theta) * dx + sin(theta) * dy
        y = yc - sin(theta) * dx + cos(theta) * dy

    Both coordinates are rounded to the nearest integer. Plain ``int()``
    truncation would turn a value such as 4.999999999999999 (floating-point
    noise around 5) into 4 and bias every corner toward zero.

    :param xc: x coordinate of the rotation centre
    :param yc: y coordinate of the rotation centre
    :param xp: x coordinate of the point to rotate
    :param yp: y coordinate of the point to rotate
    :param theta: rotation angle in radians
    :return: ``(x, y)`` as a tuple of integer strings
    """
    xoff = xp - xc
    yoff = yp - yc
    cos_theta = math.cos(theta)
    sin_theta = math.sin(theta)
    rotated_x = cos_theta * xoff + sin_theta * yoff
    rotated_y = -sin_theta * xoff + cos_theta * yoff
    return str(int(round(xc + rotated_x))), str(int(round(yc + rotated_y)))
def totxt(xml_path, out_path):
    """Write one text label file per corner-point XML file in a folder.

    For every entry in ``xml_path`` a file ``<stem>.txt`` is created in
    ``out_path``, where ``<stem>`` is the entry's name without its final
    extension. Each ``<object>`` whose ``<name>`` equals an entry of the
    module-level ``cls_list`` produces one line::

        x0 y0 x1 y1 x2 y2 x3 y3 <class name> <index in cls_list>

    Coordinates are read from the object's ``<bndbox>`` children
    ``x0 .. y3``, truncated to integers and clamped to be non-negative.
    Objects whose name is not in ``cls_list`` are skipped. A running count
    of processed files is printed after each file.

    :param xml_path: folder containing the corner-point XML files
    :param out_path: existing folder that receives the ``.txt`` files
    :return: None
    """
    corner_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')

    for count, xml_name in enumerate(os.listdir(xml_path), start=1):
        tree = ET.parse(os.path.join(xml_path, xml_name))

        # splitext keeps dotted stems intact ("img.v2.xml" -> "img.v2") and
        # os.path.join uses the platform's separator instead of a literal '\\'.
        stem = os.path.splitext(xml_name)[0]
        output = os.path.join(out_path, stem + '.txt')

        # 'with' closes the label file even if an annotation is malformed.
        with open(output, 'w') as label_file:
            for obj in tree.findall('object'):
                cls = obj.find('name').text
                box = obj.find('bndbox')
                coords = [
                    max(int(float(box.find(tag).text)), 0)
                    for tag in corner_tags
                ]
                for cls_index, cls_name in enumerate(cls_list):
                    if cls == cls_name:
                        label_file.write(
                            "{} {} {} {} {} {} {} {} {} {}\n".format(
                                *coords, cls, cls_index
                            )
                        )

        print(count)
if __name__ == '__main__':
    # Folder paths: the roLabelImg annotations to read, the folder for the
    # corner-point XML files, and the folder for the generated .txt labels.
    roxml_path = r"通过rolabelimgvia标注完成的xml文件夹"
    dotaxml_path = r'输出dota能识别的xml文件夹'
    out_path = r'输出txt文件夹'

    # Neither ElementTree.write() nor open() creates missing folders, so make
    # sure both output folders exist before converting anything.
    for folder in (dotaxml_path, out_path):
        os.makedirs(folder, exist_ok=True)

    # Step 1: rewrite every annotation XML into corner-point form.
    for xml_name in os.listdir(roxml_path):
        # Skip anything that is not an XML file (e.g. images stored alongside
        # the annotations); parsing it would raise a ParseError.
        if not xml_name.lower().endswith('.xml'):
            continue
        edit_xml(os.path.join(roxml_path, xml_name),
                 os.path.join(dotaxml_path, xml_name))

    # Step 2: turn the corner-point XML files into text label files.
    totxt(dotaxml_path, out_path)
3、创建训练集数据(格式如下)
images/train文件夹中放入模型训练所需的img图片;images/val文件夹中放入模型验证所需的img图片;比例一般按3:1进行配置即可。labels/train_original文件夹放入训练图片对应的txt文件;labels/val_original文件夹放入验证图片对应的txt文件。
4、将 DOTA 数据集格式转换为YOLO OBB 格式(执行下面转化代码即可自动转换)
# Convert the DOTA-style label files into YOLO OBB label files using the
# converter shipped with Ultralytics. './yolo_obb' is passed as the dataset
# root folder.
# NOTE(review): presumably this folder must contain the images/ and labels/
# sub-folders prepared in the previous step — confirm against the
# Ultralytics documentation for the installed version.
from ultralytics. data. converter import convert_dota_to_yolo_obb
convert_dota_to_yolo_obb( './yolo_obb' )
运行转换代码前注意事项:
1、需要进入convert_dota_to_yolo_obb中修改 class_mapping = {}
2、根据自己训练集图片格式进行修改(.jpg、.png)
5、创建自己的my_obb_data.yaml文件,创建新的yolov8-obb.yaml文件,并进行修改
1、创建自己的my_obb_data.yaml文件
2、创建新的yolov8-obb.yaml文件,并进行修改
6、进行训练操作(执行下面代码,开始进行模型训练)
from ultralytics import YOLO
def main():
    """Build the model from its YAML file, load weights into it, and train.

    The three string arguments are path placeholders: the model definition
    YAML, the weights file, and the dataset YAML. Training runs for 100
    epochs with an image size of 640.
    """
    network = YOLO('创建修改的yolov8-obb.yaml文件路径')
    model = network.load('下载好的yolov8n-obb.pt权重文件路径')
    train_options = dict(data='创建的my_obb_data.yaml文件路径', epochs=100, imgsz=640)
    model.train(**train_options)


if __name__ == '__main__':
    main()
这里一共只使用了23张图进行训练,标注也没有按一定规律进行;所以效果不是很好。前期的标注制作很重要,需要提高数据集的质量及数量。
总结:
1、获取数据集图片,按一定规律进行标注
2、按一定格式创建数据集(xml文件转换txt文件)
3、将 DOTA 数据集格式转换为YOLO OBB 格式
4、创建修改2个 xxx.yaml 文件
5、进行模型训练
以上就是yolov8 obb旋转模型训练自己的模型过程,只是记录一下自己使用过程,大佬勿喷。