AI平台训练YOLO模型的步骤流程总结

一、挂载数据

筛选小技巧:
可以一次性选中全部数据集来构建数据集,在详情页即可看到每一类有多少张图片,不需要一类一类地去查看。

数据集的筛选很重要,要注意数据是否存在编码混乱和标注错误。
可以用下面的脚本把json中的标注框画到图片上,check一下标签是否正确:

import cv2
import os
import numpy as np
import json
import sys
import random


# Class codes of this dataset. NOTE(review): not used by the drawing loop
# below, which draws each object's raw f_code string; kept for reference.
classes = ["0D0004", "0D0005", "0D2007", "0D200B", "0D2010", "0D201A", "0D6201", "0D7001", "0D7098","0D0000","0DOther"]

image_dir = "/root/store-data_new/crop1125/images"
outdir = "/root/store-data_new/data_process/check_label/check_json_images"
label_dir = "/root/store-data_new/crop1125/labels"
images = os.listdir(image_dir)
labels = os.listdir(label_dir)

# cv2.imwrite does not create missing directories (it only returns False),
# so make sure the output directory exists before drawing anything.
os.makedirs(outdir, exist_ok=True)

# Draw every json box (red rectangle + f_code text) onto its image and save a
# copy to outdir, so the labels can be checked visually.
for i in images:
    image = cv2.imread(os.path.join(image_dir, i))
    if image is None:
        # Non-image or corrupt file: report it instead of crashing below.
        print(i, "cannot be read as an image, skipped")
        continue
    # splitext keeps dotted names intact ("a.b.jpg" -> "a.b"), unlike split(".")[0].
    json_file = os.path.join(label_dir, os.path.splitext(i)[0] + ".json")
    if not os.path.exists(json_file):
        print(i, "has no matching json label, skipped")
        continue
    with open(json_file, encoding="utf-8") as ann_f:
        annos = json.load(ann_f)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for obj in annos['objects']:
        # The box is stored as a corner (x, y) plus width w and height h.
        pt = obj['obj_points'][0]
        x1, y1 = int(pt['x']), int(pt['y'])
        x2, y2 = int(pt['x'] + pt['w']), int(pt['y'] + pt['h'])
        label = obj['f_code']
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.putText(image, label, (x1 + 10, y1 + 10), font, 1, (0, 0, 255), 2)
    if not cv2.imwrite(os.path.join(outdir, i), image):
        print(i, "could not be written to", outdir)
        continue
    print(i, "json check successfully")

二、数据预处理

1、所有类别数据整合

1.1 检查图片和标签的后缀是否正确

执行下面的脚本。这个脚本在AI平台上被注释掉了;常见的问题是图片后缀".jpg"被误写为".JPG"。遇到新的数据集时最好先运行一下该脚本,以防还存在其他错误的后缀。判断后缀的脚本如下:

import os
import shutil

root_dir = "/root/data1125/"
dst_dir = "/root/store-data_new/crop1125/"  # not used by this check

# Expected layout: root_dir/<class folder>/<images|labels>/<file>.
# Collect every file suffix that is neither "jpg" nor "json" (each reported
# once, in the order first seen) so unexpected ones such as "JPG" can be fixed
# before gathering.
list_finish = []
_seen_suffixes = set()  # O(1) membership test alongside the ordered list
for data_class in os.listdir(root_dir):  # e.g. data_0001
    class_dir = os.path.join(root_dir, data_class)
    if not os.path.isdir(class_dir):
        continue  # stray files next to the class folders would crash listdir
    for target in os.listdir(class_dir):  # images / labels
        target_dir = os.path.join(class_dir, target)
        if not os.path.isdir(target_dir):
            continue
        for my_object in os.listdir(target_dir):  # xxx.jpg / xxx.json
            suffix = my_object.split(".")[-1]
            if suffix in ("jpg", "json") or suffix in _seen_suffixes:
                continue
            _seen_suffixes.add(suffix)
            list_finish.append(suffix)
print(list_finish)

1.2 将所有类别整合到一起

gather.py的内容如下:

import os
import shutil

root_dir = "/root/data1125/"
dst_dir = "/root/store-data_new/crop1125/"

# Lower-cased file suffix -> destination sub-folder under dst_dir. Suffixes are
# matched case-insensitively and written lower-case (e.g. "x.JPG" -> images/x.jpg)
# because later steps look files up as "<name>.jpg" / "<name>.json".
SUBDIR_BY_SUFFIX = {"jpg": "images", "json": "labels"}

# shutil.copy does not create missing destination directories.
for sub_dir in ("images", "labels"):
    os.makedirs(os.path.join(dst_dir, sub_dir), exist_ok=True)

# Expected layout: root_dir/<class folder>/<images|labels>/<file>.
# NOTE(review): the same image can sit in several class folders; same-named
# files overwrite each other here, so only the last copied json survives —
# confirm each exported json already holds all objects of its image.
for data_class in os.listdir(root_dir):  # e.g. data_0001
    class_dir = os.path.join(root_dir, data_class)
    if not os.path.isdir(class_dir):
        continue
    for target in os.listdir(class_dir):  # images / labels
        target_dir = os.path.join(class_dir, target)
        if not os.path.isdir(target_dir):
            continue
        for my_object in os.listdir(target_dir):  # xxx.jpg / xxx.json
            src_path = os.path.join(target_dir, my_object)
            # rpartition keeps dotted stems intact ("a.b.JPG" -> "a.b").
            stem, _, suffix = my_object.rpartition(".")
            sub_dir = SUBDIR_BY_SUFFIX.get(suffix.lower()) if stem else None
            if sub_dir is None:
                # Unknown suffixes must be skipped explicitly: otherwise the copy
                # would reuse the previous file's destination and overwrite it.
                print(src_path, "has an unexpected suffix, skipped")
                continue
            dst_path = os.path.join(dst_dir, sub_dir, stem + "." + suffix.lower())
            shutil.copy(src_path, dst_path)
            print(dst_path, "copy successfully")

1.3 查看整合前后图片标签数量的变化

一般整合后的数目要小于整合前,因为同一张图片可能包含多个类别的目标,会同时出现在多个类别文件夹中,整合时同名文件会相互覆盖。

import os
import shutil

root_dir = "/root/data1125/"
image_dst_dir = "/root/store-data_new/crop1125/images/"
json_dst_dir = "/root/store-data_new/crop1125/labels/"

# Print per-class image/json counts before merging, the total image count
# before merging, and the image/json counts after merging.
class_list = os.listdir(root_dir)
sum_num = 0  # total images over all class folders (before merging)
for test_class in class_list:
    second_dir = os.path.join(root_dir, test_class)
    image_dir = os.path.join(second_dir, "images")
    json_dir = os.path.join(second_dir, "labels")
    if not os.path.isdir(image_dir):
        continue  # not a class folder (or no images in it)
    images = os.listdir(image_dir)
    # A class with images but no labels folder is reported as 0 jsons instead
    # of aborting the whole count.
    jsons = os.listdir(json_dir) if os.path.isdir(json_dir) else []
    sum_num += len(images)
    print(test_class, "images num", len(images), "jsons num", len(jsons))
print(sum_num)
print(len(os.listdir(image_dst_dir)), len(os.listdir(json_dst_dir)))
    

2、格式转换

需要将平台导出的json标注转换为YOLO(darknet)训练所需的txt格式:先转换成VOC格式的xml,再由xml转换成txt。

2.1 json转换成VOC格式的xml

config.yaml的内容:

# Read by transform2voc.py (aicloud json_to_voc.json2voc).
# Directory of the gathered images.
INPUT_IMAGES_DIR: '/root/store-data_new/crop1125/images/' 
# Directory of the gathered platform json labels.
INPUT_LABELS_DIR: '/root/store-data_new/crop1125/labels/'  
# Output root. NOTE(review): the next step reads VOCdevkit/VOC2007/... under
# this directory, presumably created by json2voc — confirm.
OUTPUT_DIR: '/root/store-data_new/crop1125/'  

transform2voc.py的内容:

# Convert the platform json annotations to VOC xml with the platform SDK.
# NOTE(review): json2voc's return value and exact output layout are not shown
# here; the xml -> txt step expects VOCdevkit/VOC2007/Annotations/*.xml and
# ImageSets/Main/{train,val,test}.txt under OUTPUT_DIR — confirm.
from aicloud import json_to_voc
# Path to the config.yaml holding INPUT_IMAGES_DIR / INPUT_LABELS_DIR / OUTPUT_DIR.
yamlPath = '/root/store-data_new/data_process/transform_data/config.yaml'
result = json_to_voc.json2voc(yamlPath)

2.2 xml转换成txt

注意:
1、如果后续还需要裁剪图片,则先不要执行归一化(即convert函数中除以图片宽高的步骤),保留像素坐标;
2、如果后续不需要裁剪、直接参与训练,则必须归一化(YOLO的txt标签要求坐标是相对于图片宽高的比例)。

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

# (year, image_set) pairs to convert; each reads the image ids from
# VOCdevkit/VOC<year>/ImageSets/Main/<image_set>.txt.
sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

# Class names to keep. The class id written to each txt label is the index in
# this list, and objects with other names are dropped, so order matters.
# NOTE(review): any script that maps ids back to names (e.g. the txt label
# check, and presumably crop/crop.names) must use exactly this list and order.
classes = ["0D0001", "0D0004", "0D0005", "0D2007", "0D200B", "0D2010", "0D201A", "0D6201", "0D7001", "0D7098"]


def convert(size, box):
    """Convert a pixel box to YOLO's normalised (x_center, y_center, width, height).

    Args:
        size: (image_width, image_height) in pixels.
        box: (xmin, xmax, ymin, ymax) in pixels -- note the x, x, y, y order.

    Returns:
        Tuple (x, y, w, h); x and w are scaled by 1/width, y and h by 1/height.
        The centre is shifted by -1 pixel before scaling.
        NOTE(review): the -1 mirrors darknet's voc_label.py (1-based VOC pixel
        coordinates); confirm it is wanted for this platform's annotations.
    """
    inv_width = 1. / (size[0])
    inv_height = 1. / (size[1])
    xmin, xmax, ymin, ymax = box[0], box[1], box[2], box[3]
    center_x = (xmin + xmax) / 2.0 - 1
    center_y = (ymin + ymax) / 2.0 - 1
    return (
        center_x * inv_width,
        center_y * inv_height,
        (xmax - xmin) * inv_width,
        (ymax - ymin) * inv_height,
    )

def convert_annotation(year, image_id):
    """Write the YOLO txt label for one VOC xml annotation.

    Reads VOCdevkit/VOC<year>/Annotations/<image_id>.xml and writes
    VOCdevkit/VOC<year>/labels/<image_id>.txt with one line per kept object:
    "<class_id> <x> <y> <w> <h>" as computed by ``convert``. Objects whose name
    is not in ``classes`` or that are marked difficult are skipped.

    Raises:
        xml.etree.ElementTree.ParseError: if the xml is empty or corrupt (for
            example extracted on a full disk); the offending path is printed
            first. The txt file is only created after a successful parse.
    """
    voc_dir = '/root/store-data_new/crop1125/VOCdevkit/VOC%s' % (year)
    xml_path = '%s/Annotations/%s.xml' % (voc_dir, image_id)
    txt_path = '%s/labels/%s.txt' % (voc_dir, image_id)
    try:
        tree = ET.parse(xml_path)  # opens and closes the file itself
    except ET.ParseError:
        print(xml_path, "could not be parsed (empty or corrupt xml)")
        raise
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open(txt_path, 'w') as out_file:
        for obj in root.iter('object'):
            cls = obj.find('name').text
            if cls not in classes:
                continue
            # Some xml writers omit <difficult>; treat a missing/empty tag as 0.
            difficult = obj.find('difficult')
            if difficult is not None and difficult.text and int(difficult.text) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # convert() expects (xmin, xmax, ymin, ymax).
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()  # only needed by the disabled image-list line below

# Convert every image id listed in each ImageSets split to a YOLO txt label.
for year, image_set in sets:
    # exist_ok=True replaces the separate exists() check.
    os.makedirs('/root/store-data_new/crop1125/VOCdevkit/VOC%s/labels/'%(year), exist_ok=True)
    with open('/root/store-data_new/crop1125/VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)) as ids_file:
        image_ids = ids_file.read().strip().split()
    # The image-list file ('%s_%s.txt' % (year, image_set), e.g. 2007_train.txt)
    # is deliberately not opened here: nothing is written to it, and opening it
    # with 'w' would silently truncate an existing list in the current directory.
    # To generate it, open it in a `with` block and re-enable:
    #     list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
    for image_id in image_ids:
        convert_annotation(year, image_id)

# os.system("cat 2007_train.txt 2007_val.txt 2012_train.txt 2012_val.txt > train.txt")
# os.system("cat 2007_train.txt 2007_val.txt 2007_test.txt 2012_train.txt 2012_val.txt > train.all.txt")

运行时有可能遇到xml解析报错,这一般是因为报错的xml文件是空的。常见原因是磁盘空间不足,导致解压出来的xml为空。可以尝试将数据重新压缩为tar或zip格式,并在扩大磁盘空间或腾出空间后重新解压。

3、图片裁剪(可选)

4、check处理过的图片和txt标签

import cv2
import os
import numpy as np

# Must be the exact ``classes`` list (same names, same order) used when the txt
# labels were written in the xml -> txt step: the first field of every txt line
# is an index into that list, so any mismatch draws the wrong class names.
classes = ["0D0001", "0D0004", "0D0005", "0D2007", "0D200B", "0D2010", "0D201A", "0D6201", "0D7001", "0D7098"]

image_dir = "/root/store-data_new/crop1125/VOCdevkit/VOC2007/JPEGImages"
outdir = "/root/store-data_new/data_process/check_label/check_images"
label_dir = "/root/store-data_new/crop1125/VOCdevkit/VOC2007/labels"
images = os.listdir(image_dir)
labels = os.listdir(label_dir)

# cv2.imwrite does not create missing directories (it only returns False).
os.makedirs(outdir, exist_ok=True)

# Draw every YOLO txt box back onto its image and save a copy to outdir.
for i in images:
    image = cv2.imread(os.path.join(image_dir, i))
    if image is None:
        # Non-image or corrupt file: report it instead of crashing below.
        print(i, "cannot be read as an image, skipped")
        continue
    w = image.shape[1]
    h = image.shape[0]
    # splitext keeps dotted names intact ("a.b.jpg" -> "a.b").
    label_path = os.path.join(label_dir, os.path.splitext(i)[0] + ".txt")
    if not os.path.exists(label_path):
        print(i, "has no matching txt label, skipped")
        continue
    font = cv2.FONT_HERSHEY_SIMPLEX
    with open(label_path, 'r') as f:
        for line in f:
            line_strs = line.split()
            if not line_strs:
                continue  # tolerate blank lines
            # Line format: <class_id> <x_center> <y_center> <width> <height>,
            # the last four scaled by the image width/height.
            cls_id = int(line_strs[0])
            x = float(line_strs[1]) * w
            y = float(line_strs[2]) * h
            box_w = float(line_strs[3]) * w
            box_h = float(line_strs[4]) * h
            xmin = x - box_w / 2
            xmax = x + box_w / 2
            ymin = y - box_h / 2
            ymax = y + box_h / 2
            cv2.rectangle(image, (int(xmin), int(ymax)), (int(xmax), int(ymin)), (0, 0, 255), 2)
            # An out-of-range id means the class list does not match the one
            # used for conversion; show the raw id instead of raising (or, for a
            # negative id, silently picking a name from the end of the list).
            label = classes[cls_id] if 0 <= cls_id < len(classes) else "%d?" % cls_id
            cv2.putText(image, label, (int(xmin) + 10, int(ymax) + 10), font, 1, (0, 0, 255), 2)

    if not cv2.imwrite(os.path.join(outdir, i), image):
        print(i, "could not be written to", outdir)
        continue
    print(i, "check successfully")

三、训练模型

1、crop.cfg的修改

1.1 修改anchor

./darknet detector calc_anchors crop/crop.data -num_of_clusters 9 -width 608 -height 608

1.2 修改类别

同时修改classes和filters:每个[yolo]层中的classes改为实际类别数,其前面紧邻的[convolutional]层中的filters = 3 * (classes + 5),共有三处。

2、crop.data的修改

3、crop.names的修改

4、开始训练

./darknet detector train crop/crop.data crop/crop.cfg backup/yolov4.conv.137 -gpus 0,1 -dont_show -mjpeg_port 8092 -map
./darknet detector train crop/crop.data crop/crop.cfg backup/yolov4.conv.137 -gpus 0 -dont_show -mjpeg_port 8092 -map

5、测试模型

(1) 测试单张图片:

./darknet detector test crop/crop.data crop/crop.cfg backup/crop_best.weights test.jpg -dont_show

(2) 计算mAP:

./darknet detector map crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights 

默认的IoU阈值是0.5;也可以通过-iou_thresh显式指定IoU阈值来计算mAP:

./darknet detector map crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights -iou_thresh 0.5 

(3) 计算recall:

./darknet detector recall crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights 

(4) 批量测试

 ./darknet detector valid crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights -dont_show -thresh 0.25
./darknet detector test crop/crop.data crop/crop.cfg crop/crop_best.weights -ext_output  -dont_show -out result.json < /root/data/empty/order/voc_2012.txt

测试结果默认保存在results文件夹内,把测试结果可视化为图片的代码如下:

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值