【记录】使用YOLOV3训练并测试自己的数据集

labelme制作数据集

  • 首先安装labelme工具
pip install pyqt5  
pip install labelme
  • 运行labelme
labelme
  • 会出现下图所示窗口,labeme
  • 点击Open Dir 打开存放数据集的文件夹。
  • 点击Edit->Create-Rectangle开始做框
  • 标注label,点击ok保存
    在这里插入图片描述

文件夹下会出现对应图片.json文件标注信息
在这里插入图片描述
接下来用以下代码生成数据集对应的train.txt文件和label.txt文件

# coding:utf-8
# 生成darknet需要的标签

import os
import os.path as path
import json

# 保存所有类别对应的id的字典
class_id_dict = {}

# 判断是不是图片
def isPic(basename):
    file_type = basename.split('.')[-1]
    pic_file_list = ['png', 'jpg', 'jpeg', 'BMP', 'JPEG', 'JPG', 'JPeG', 'Jpeg', 'PNG', 'TIF', 'bmp', 'tif']
    if file_type in pic_file_list:
        return True
    return False

# 判断这个图片有没有对应的json文件
def has_json(img_file):
    # 得到json文件名
    base_name = path.basename(img_file)
    dir_name = img_file[:len(img_file) - len(base_name)]
    json_name = base_name.split('.')[0]
    json_name = json_name + '.json'
    json_name = path.join(dir_name, json_name)
    if path.isfile(json_name):
        return json_name
    return None

# 生成file_name的label文件,并重新写入 content_list 中内容
def rewrite_labels_file(file_name, content_list):
    with open(file_name, 'w') as f:
        for line in content_list:
            curr_line_str = ''
            for element in line:
                curr_line_str += str(element) + ' '
            f.write(curr_line_str + '\n')
    return

# 生成file_name的训练图片路径文件
def rewrite_train_name_file(file_name, content_list):
    with open(file_name, 'w') as f:
        for line in content_list:
            f.write(str(line) + '\n')
    return 

# 读取文件
def read_file(file_name):
    if not path.exists(file_name):
        print("warning:不存在文件"+str(file_name))
        return None
    with open(file_name, 'r', encoding='utf-8') as f:
        result = []
        for line in f.readlines():
            result.append(line.strip('\n'))
        return result

# 加载class_id
def load_class_id(class_name_file):
    global class_id_dict
    class_list = read_file(class_name_file)
    for i in range(len(class_list)):
        class_id_dict[str(class_list[i])] = i
    return class_id_dict

# 得到分类的id,未分类是-1
def get_id(class_name, class_name_file):
    global class_id_dict
    if len(class_id_dict) < 1:
        class_id_dict = load_class_id(class_name_file)
        print("分类 id 加载完成")
    # 补丁:替换掉汉字 "局""段"
    class_name = get_id_patch(class_name)
    if class_name in class_id_dict.keys():
        return class_id_dict[class_name]
    return -1
# 去掉汉字'段'和'局'
def get_id_patch(class_name):
    if class_name.strip() == '段':
        return 'duan'
    if class_name.strip() == '局':
        return 'ju'
    return class_name

# 解析一个points,得到坐标序列
def get_relative_point(img_width, img_height, point_list):
    # point_list是一个包含两个坐标的list
    x_min = min(point_list[0][0], point_list[1][0])
    y_min = min(point_list[0][1], point_list[1][1])
    x_max = max(point_list[0][0], point_list[1][0])
    y_max = max(point_list[0][1], point_list[1][1])
    dw = 1.0 / img_width
    dh = 1.0/ img_height
    # 中心坐标
    x = (x_min + x_max)/2.0
    y = (y_min + y_max)/2.0
    w = x_max - x_min
    h = y_max - y_min
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return [x, y, w, h]

# 解析json文件
def paras_json(json_file, class_name_file):
    if not path.exists(json_file):
        print("warning:不存在json文件" + str(json_file))
        assert(0)
    # 读取json文件拿到基本信息, encoding要注意一下
    try:
        f = open(json_file, encoding="gbk")
        setting = json.load(f)
    except:
        f = open(json_file, encoding='utf-8')
        setting = json.load(f)
    f.close()
    shapes = setting['shapes']          # 框
    height = setting['imageHeight']
    width = setting['imageWidth']
    # 拿到标签坐标
    result = []
    for shape in shapes:
        class_name = shape['label'] # 得到分类名
        # 没有这个分类就不要
        class_id = get_id(class_name, class_name_file)
        if class_id < 0:
            continue
        locate_result = get_relative_point(width, height, shape['points'])
        # 插入id
        locate_result.insert(0, class_id)
        result.append(locate_result)
    return result

# 得到文件夹下所有的图片文件
def get_pic_file_from_dir(dir_name):
    '''
        return:所有的图片文件名
    '''
    if not path.isdir(dir_name):
        print("warning:路径 %s 不是文件夹" %dir_name)
        return []
    result = []
    for f in os.listdir(dir_name):
        curr_file = path.join(dir_name, f)
        if not path.isfile(curr_file):
            continue
        if not isPic(curr_file):
            continue
        result.append(f)
    return result

def main(class_name='train.names', img_dir='JPEGImages', train_txt='train.txt', labels_dir='labels'):
    cwd = os.getcwd()
    img_dir = path.join(cwd, img_dir)
    labels_dir = path.join(cwd, labels_dir)
    if not path.exists(img_dir):
        print("error:没有发现图片文件夹 ", img_dir)
    if not path.exists(labels_dir):
        os.mkdir(labels_dir)
    
    count = 0                                                   # 进度条
    dir_len = len(os.listdir(img_dir))  # 进度条

    imgs = []
    for f in os.listdir(img_dir): 
        curr_path = path.join(img_dir, f)
        if not path.isdir(curr_path):   # 不是文件夹就先跳过
            continue
        curr_train_dir = curr_path
        # 是文件夹就创建labels对应的文件夹
        curr_labels_dir = path.join(labels_dir, f)
        if not path.isdir(curr_labels_dir):
            os.mkdir(curr_labels_dir)
        # 拿到文件夹下所有的图片文件
        curr_dir_imgs = get_pic_file_from_dir(curr_train_dir)
        # 解析这些图片的json文件
        for img_file in curr_dir_imgs:
            curr_img_file = path.join(curr_train_dir, img_file)
            json_file = has_json(curr_img_file)
            if json_file:
                # 保存图片路径
                imgs.append(curr_img_file)
                # 得到json信息 list
                json_inf = paras_json(json_file, class_name)
                # 标签文件名
                label_name = img_file.split('/')[-1].split('.')[0] + '.txt'
                curr_labels_file = path.join(curr_labels_dir, label_name)
                # 写入标签
                rewrite_labels_file(curr_labels_file, json_inf)
        count += 1
        print("\r当前进度: {:02f} %".format(count/dir_len * 100.0), end='')
    print("\n 保存训练图片路径到: ", train_txt)
    rewrite_train_name_file(train_txt, imgs)
    return 

if __name__ == "__main__":
    main()

注意修改其中的img_dir参数

自己的数据集就制作完成了!下面进入正题!

PyTorch-YOLOv3

项目来自: https://github.com/eriklindernoren/PyTorch-YOLOv3.

  • 首先是环境安装,根据requirement.txt安装依赖的包。
  • 然后进行一些配置的修改,首先运行以下命令来创建yolov3-custom.cfg
$ cd config/                                # Navigate to config dir
$ bash create_custom_model.sh <num-classes> # Will create custom model 'yolov3-custom.cfg'
  • data/custom/classes.names文件中添加label类别的名字
  • 把数据集的图片放入文件夹data/custom/images/
  • 把label放入文件夹 data/custom/labels/
  • data/custom/train.txtdata/custom/valid.txt添加用于训练集和验证集的图片的路径。(具体可以查看项目中的readme.md)
    大功告成,接下来就可以训练了
$ python train.py --model_def config/yolov3-custom.cfg --data_config config/custom.data

遇到的问题

不遂人愿,我在训练时遇到了以下两个问题,在翻查了issues后得到了解决:

Traceback (most recent call last):
File "train.py", line 176, in 
ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
IndexError: list index out of range
  • 解决方法:在class.names最后一行加个回车
true_positives, pred_scores, pred_labels = [np.concatenate(x, 0) for x in list(zip(*sample_metrics))]
ValueError: not enough values to unpack (expected 3, got 0)
  • 解决方法:将代码中150行的0改为1
if epoch % opt.evaluation_interval == 0:
print("\n---- Evaluating Model ----")
change to
if epoch % opt.evaluation_interval == 1:
print("\n---- Evaluating Model ----")

ok,现在可以开始训练了,训练完后可以detect试试效果,注意修改命令行参数。
参考:
[1] https://blog.csdn.net/u014061630/article/details/88756644
[2] https://github.com/eriklindernoren/PyTorch-YOLOv3
[3]:https://github.com/rrddcc/YOLOv4_tensorflow

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值