使用自己数据集训练yolov5-6.0

目录

1.生成数据集

 2.运行train.py脚本

3.运行val.py脚本

4. 运行detect.py脚本

最后,附杂草预测结果图


简要说明:

  • 默认yolov5-6.0已下载到本地,并相关依赖包已经安装完毕
  • yolov5-6.0代码下载地址:yolov5-6.0代码下载地址
  • 需要下载的相关依赖包在requirems.txt中,在终端输入命令:pip install -r requirements.txt,即可自动安全相关依赖包
  • python版本>=3.8,PyTorch>=1.6

1.生成数据集

概述:使用labelimg标注图片数据(pascal voc格式),如想使用离线目标检测数据增强,可参照这篇博文(模型本身就有在线数据增强),然后将pascal voc数据格式转换为coco数据格式,并划分为训练集和测试集,生成train2017.txt和val2017.txt(训练集和测试集的图片路径)

1)官网下载:labelimg官网下载

2)进入cmd命令框,输入labelimg,点击回车,就会弹出labelimg图像标注工具

3)标注图片,Open Dir选择D/data/image(改成你自己的图片所在路径),Change Save Dir选择D/data/annotations(改成所要保存的xml文件路径),因为还要将其转换为coco数据格式,因此只需要这两个文件夹

 标注完所有类别后直接点击Save,然后点击Next Image即可

4)在生成pascal voc数据集后,将数据集转为coco数据集,并生成训练集、测试集、train2017.txt和val2017.txt。 注意:

  • 该脚本还需要一个含所有类别标签的json文件
  • 为了避免出错,文件夹格式尽量于coco数据集文件格式一致

生成json文件代码和内容如下:(目标检测的类别都是从1开始,0默认表示背景图)

import json

dict_class={'weed':1,
'crop':2,
}
def json_file(dict_class):
    json_str = json.dumps(dict_class, indent=1)
    with open('./classes.json', 'w') as json_file:
        json_file.write(json_str)
json_file(dict_class)

将数据集转为coco数据集,并生成训练集、测试集、train2017.txt和val2017.txt代码如下

"""
本脚本有三个功能:
1.将voc数据集标注信息(.xml)转为yolo标注格式(.txt),包含标签和相应的box坐标
2.划分为训练集(80%)和验证集(20%)
3.生成train2017.txt和val2017.txt
"""
import os
from tqdm import tqdm
from lxml import etree
import json
import cv2


def parse_xml_to_dict(xml):
    """
    将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
    Args:
        xml: xml tree obtained by parsing XML file contents using lxml.etree

    Returns:
        Python dictionary holding XML contents.
    """

    if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
        return {xml.tag: xml.text}

    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)  # 递归遍历标签信息
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])

    return {xml.tag: result}


def translate_info(file_names: list,resouce_imgPath,resouce_xmlPath,save_txt_path, class_dict: dict):
    """
    将对应xml文件信息转为yolo中使用的txt文件信息
    :param file_names:
    :param save_root:
    :param class_dict:
    :param train_val:
    :return:
    """
    if os.path.exists(save_txt_path) is False:
        os.makedirs(save_txt_path)
    if os.path.exists(resouce_imgPath) is False:
        os.makedirs(resouce_imgPath)


     # 检查下图像文件是否存在
    img_path =resouce_imgPath+'/'+file_names+'jpg'
    assert os.path.exists(img_path), "file:{} not exist...".format(img_path)

    # 检查xml文件是否存在
    xml_path = resouce_xmlPath+'/'+file_names+'xml'
    assert os.path.exists(xml_path), "file:{} not exist...".format(xml_path)

    # read xml
    with open(xml_path) as fid:
        xml_str = fid.read()
    xml = etree.fromstring(xml_str)
    data = parse_xml_to_dict(xml)["annotation"]
    img_height = int(data["size"]["height"])
    img_width = int(data["size"]["width"])

    # write object info into txt
    assert "object" in data.keys(), "file: '{}' lack of object key.".format(xml_path)
    if len(data["object"]) == 0:
        # 如果xml文件中没有目标就直接忽略该样本
        print("Warning: in '{}' xml, there are no objects.".format(xml_path))

    save_txt=save_txt_path+'\\'+file_names+'txt'
    with open(save_txt, "w+") as f:
        for index, obj in enumerate(data["object"]):
            label=obj
            # 获取每个object的box信息
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            class_name = obj["name"]
            class_index = class_dict[class_name] - 1  # 目标id从0开始

            # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue

            # 将box信息转换到yolo格式
            xcenter = xmin + (xmax - xmin) / 2
            ycenter = ymin + (ymax - ymin) / 2
            w = xmax - xmin
            h = ymax - ymin

            # 绝对坐标转相对坐标,保存6位小数
            xcenter = round(xcenter / img_width, 6)
            ycenter = round(ycenter / img_height, 6)
            w = round(w / img_width, 6)
            h = round(h / img_height, 6)

            info = [str(i) for i in [class_index, xcenter, ycenter, w, h]]

            if index == 0:
                f.write(" ".join(info))
            else:
                f.write("\n" + " ".join(info))


def trainval(resouce_imgPath,resouce_xmlPath,train_rate,save_imgPath,save_txtPath):
    l=len(os.listdir(resouce_imgPath))
    print("总的长度len{}".format(l))
    vl=int(l*(1-train_rate)+1)
    k=int(l/vl)
    print("总的长度val_len{}".format(vl))
    i=0
    count=0
    # read class_indict
    json_file = open(label_json_path, 'r')
    class_dict = json.load(json_file)
    for img in os.listdir(resouce_imgPath):
        xmlName=img[:-3]
        i+=1
        s = 'train2017'
        if i%int(l/vl)==0:
            count+=1
            print("valLen{}".format(count))
            s='val2017'
            # save_xml_path=save_imgPath+'/'+s
        save_txt_path=save_txtPath+'/'+s
        # 调用保存img文件和txt文件函数
        saveImgPath=save_imgPath+'/'+s+img
        im=cv2.imread(resouce_imgPath+'/'+img)
        cv2.imwrite(save_imgPath+'/'+s+'/'+img,im)
        translate_info(xmlName, resouce_imgPath,resouce_xmlPath,save_txt_path, class_dict)

def trainval_txt(save_img_path,image_path,train_or_val):
    # image_path = './yolo_data/coco128/images/train2017/'  # 修改为自己的路径
    file_path = save_img_path + train_or_val + '.txt'
    file = open(file_path, 'w')  # 修改为自己的路径
    for filename in os.listdir(image_path):
        # print(filename)
        file.write(image_path + filename)
        file.write('\n')
    file.close()


if __name__ == "__main__":
    #先将相应pascal数据集下生成labels(coco),并划分tain和val
    label_json_path = './classes.json'
    #可以换成自己想要的划分比例
    train_rate=0.8
    label = ['train2017', 'val2017']
    #改成自己原来的img路径和xml文件路径
    resouce_imgPath='./new/new_xmls'
    resouce_xmlPath= './new/new_imgs'
    #改成所要保存的img路径和xml文件路径
    save_imgPath='./yolo_data/from_0_start/images'
    save_txtPath='./yolo_data/from_0_start/labels'
    trainval(resouce_imgPath, resouce_xmlPath, train_rate, save_imgPath, save_txtPath)
    for i in range(len(label)):
        save_img_path = './yolo_data/from_0_start/'
        image_path = './yolo_data/{}/images/{}/'.format(label[i])
        trainval_txt(save_img_path, image_path, label[i])

    #将不同数据集进行划分(没有多个数据集,就不需要用该部分)
    # file_name=['250imgs','500imgs','1000imgs']
    # for k in range(len(file_name)):
    #     resouce_imgPath='./pascal_data/{}/VOCdevkit/VOC2007/JPEGImages'.format(file_name[k])
    #     resouce_xmlPath='./pascal_data/{}/VOCdevkit/VOC2007/Annotations'.format(file_name[k])
    #     save_imgPath='./yolo_data/{}/images'.format(file_name[k])
    #     save_txtPath='./yolo_data/{}/labels'.format(file_name[k])
    #     trainval(resouce_imgPath, resouce_xmlPath, train_rate, save_imgPath, save_txtPath)
    #     for i in range(len(label)):
    #         save_img_path = './yolo_data/{}/'.format(file_name[k])
    #         image_path='./yolo_data/{}/images/{}/'.format(file_name[k],label[i])
    #         trainval_txt(save_img_path, image_path, label[i])





 2.运行train.py脚本

概述:该部分主要修改两个配置文件,模型配置文件和读取数据配置文件

1)模型配置文件,位于源码的models模块下,主要 这里以yolov5s.yaml为模板,将nc对应的种类改为自己的类别总类(我实验的类别种类是两种weed和crop,因此改为2),其他地方不变,生成yolov5s_my.yaml

2)读取数据配置文件,位于data文件目录下,生成自己的读取数据集的配置文件,这里保存为my1000imgsData.yaml

 3)修改train脚本的相关参数(主要是修改为上述生成的配置文件),如下,同时下载预训练权重yolov5s.pt放到weights文件夹下(手动下载更快)。预训练权重下载地址(选择yolov5s.pt)

4)上述准备完毕,就可以运行train.py脚本了,结果存在runs/train/exp(i)中

3.运行val.py脚本

1)修改val.py的读取数据配置文件和权重文件,如下:

 2)这里还需要修改模型文件,因为yolov5使用coco数据集训练模型,其种类有80种,而我的只有2种,因此需要找到相应的模型文件进行修改。(如果疑惑为什么train.py不用修改,而val.py需要修改:主要原因是在train.py中就修改了模型的配置文件,不信你往上翻,会找到yolov5s_my.yaml。而val.py脚本内没有导入模型配置文件,因此需要在原模型上修改)

3)上述步骤完成后,就可以运行val.py脚本了。如果运行不了,就是上述步骤修改出错了,多检查一下。上述工作完成后,即可运行val.py文件,生成的结果保存在runs/val/exp(i)中

4. 运行detect.py脚本

1)将需要预测的图片放入data/images文件夹中,并修改detect.py内的权重文件,如下

 

 2)上述工作完成后,即可运行detect.py文件,生成的结果保存在runs/detect/exp(i)中

最后,附杂草预测结果图

  • 7
    点赞
  • 56
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小岑要努力

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值