YOLO训练数据处理工具

YOLO训练数据处理工具

为了方便网络的训练,写了一些基于python和opencv库的数据处理工具。

1.读取视频并保存每一帧

程序来源:https://blog.csdn.net/qq_43569111/article/details/103313154?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase

import cv2
import glob
import os
from datetime import datetime

def video_to_frames(path, output_dir="frames"):
    """Save every decodable frame of a video file as a JPEG image.

    Frames are written to ``<output_dir>/frames<i>.jpg`` with ``i`` counting
    from 0 in decoding order.

    Args:
        path: path of the video file.
        output_dir: folder that receives the images; it is created if it does
            not exist. Defaults to ``"frames"`` (relative to the working
            directory).

    Raises:
        OSError: if the video cannot be opened.
    """
    # VideoCapture is OpenCV's video reader.
    videoCapture = cv2.VideoCapture()
    try:
        if not videoCapture.open(path):
            raise OSError("cannot open video: %s" % path)
        # Frame rate and frame count come from the container metadata. The
        # count may be inaccurate for some files, so it is only reported and
        # does not drive the loop below.
        fps = videoCapture.get(cv2.CAP_PROP_FPS)
        frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
        print("fps=", int(fps), "frames=", int(frames))

        os.makedirs(output_dir, exist_ok=True)
        i = 0
        while True:
            ret, frame = videoCapture.read()
            if not ret:  # end of stream (or a frame that could not be decoded)
                break
            cv2.imwrite(os.path.join(output_dir, "frames%d.jpg" % (i)), frame)
            i += 1
    finally:
        # Always free the capture handle, even if a write fails.
        videoCapture.release()

if __name__ == '__main__':
    # Extract the frames of test.avi and report how long it took.
    started_at = datetime.now()
    video_to_frames("test.avi")
    elapsed = datetime.now() - started_at
    print("Time cost = ", elapsed)
    print("SUCCEED !!!")

2.批量按顺序重命名图片数据

import os

pre_num = 0  # number of images already renamed; numbering continues after them
path = "C:\\Users\\admin\\Desktop\\Fire_images\\new_image"  # folder to process
# All entries of the folder (files and sub-folders). They are sorted so the
# numbering follows a deterministic order; os.listdir returns entries in
# arbitrary order.
filelist = sorted(os.listdir(path))
count = 1 + pre_num

for file in filelist:
    print(file)
for file in filelist:
    Olddir = os.path.join(path, file)  # current path of the entry
    if os.path.isdir(Olddir):  # sub-folders are skipped and not numbered
        continue
    filetype = os.path.splitext(file)[1]  # extension, e.g. ".jpg"
    # zfill left-pads the number with zeros to 6 digits, e.g. 000123.jpg
    Newdir = os.path.join(path, str(count).zfill(6) + filetype)
    if Newdir != Olddir:
        # Never overwrite. On POSIX, os.rename silently replaces an existing
        # target, which would destroy an image that has not been renamed yet.
        if os.path.exists(Newdir):
            raise FileExistsError('target already exists: %s' % Newdir)
        os.rename(Olddir, Newdir)
    count += 1

3.批量输出图片绝对路径

将图片的绝对路径输出到txt文件后,可以通过查找和替换,把它们改成训练程序加载图像时所用的相对路径。

import os.path
import glob
import os

if __name__ == "__main__":
    realpath = os.path.realpath(__file__)  # absolute path of this script
    dirname = os.path.dirname(realpath)    # folder that contains this script
    extension = 'jpg'                      # file type to list
    # Search the script's own folder, which is also where route.txt goes, and
    # not the current working directory. Otherwise, running the script from
    # elsewhere would record paths to files that do not exist.
    pattern = os.path.join(glob.escape(dirname), '*.' + extension)
    file_list = sorted(glob.glob(pattern))  # absolute paths, in a stable order
    with open(os.path.join(dirname, 'route.txt'), 'w') as filetxt:
        for filepath in file_list:
            filetxt.write('%s\n' % (filepath))

4.xml文件转换为yolo格式txt文件

YOLO格式的标注文件中每个目标占一行:<object-class> <x center> <y center> <width> <height>,其中中心坐标和宽、高都已分别除以图像的宽和高,归一化到0~1之间。

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

# (dataset folder, image-set name) pairs to convert.
sets = [
    ('myData', 'train'),
]
# Class names to keep; a name's position in this list is its YOLO class id.
classes = [
    'fire',
]

def convert(size, box):
    """Turn a pixel bounding box into normalised YOLO coordinates.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels. Note that the x pair
            comes first.

    Returns:
        ``(x_center, y_center, width, height)``. Each value is divided by the
        image width or height and rounded to 6 decimal places.

    The centre is shifted by -1 pixel before normalising. This is the same
    offset darknet's ``voc_label.py`` uses, presumably to convert VOC's
    1-based pixel coordinates to 0-based; confirm it against your annotation
    tool.
    """
    inv_w = 1. / (size[0])
    inv_h = 1. / (size[1])
    xmin, xmax, ymin, ymax = box[0], box[1], box[2], box[3]
    center_x = (xmin + xmax) / 2.0 - 1
    center_y = (ymin + ymax) / 2.0 - 1
    box_w = xmax - xmin
    box_h = ymax - ymin
    scaled = (
        (center_x, inv_w),
        (center_y, inv_h),
        (box_w, inv_w),
        (box_h, inv_h),
    )
    # round() keeps 6 digits after the decimal point.
    return tuple(round(value * factor, 6) for value, factor in scaled)

def convert_annotation(year, image_id):
    """Convert one Pascal VOC XML annotation into a YOLO label file.

    Reads ``myData/Annotations/<image_id>.xml`` and writes
    ``myData/labels/<image_id>.txt``. The output has one
    ``<class_id> <x_center> <y_center> <width> <height>`` line per kept
    object. Objects whose name is not in ``classes``, or whose ``difficult``
    flag is 1, are skipped.

    Args:
        year: unused; kept for compatibility with existing callers.
        image_id: base name (without extension) of the annotation/image.
    """
    # Give ElementTree the path so it reads bytes and honours the XML
    # encoding declaration, instead of decoding with the platform's default
    # text encoding.
    tree = ET.parse('myData/Annotations/%s.xml' % (image_id))
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open('myData/labels/%s.txt' % (image_id), 'w') as out_file:
        for obj in root.iter('object'):
            # VOC "difficult" flag (distinct from the separate "truncated"
            # tag); a missing tag is treated as 0.
            difficult_node = obj.find('difficult')
            difficult = int(difficult_node.text) if difficult_node is not None else 0
            cls = obj.find('name').text
            if cls not in classes or difficult == 1:
                continue

            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # Order expected by convert(): (xmin, xmax, ymin, ymax).
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)

            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()  # current working directory; image paths in the list file start from it
for year, image_set in sets:
    os.makedirs('myData/labels/', exist_ok=True)
    # One image id per whitespace-separated token of the image-set file.
    with open('myData/ImageSets/Main/%s.txt' % (image_set)) as id_file:
        image_ids = id_file.read().strip().split()
    with open('myData/%s_%s.txt' % (year, image_set), 'w') as list_file:
        for image_id in image_ids:
            # List the image's path for training and write its YOLO label file.
            list_file.write('%s/myData/JPEGImages/%s.jpg\n' % (wd, image_id))
            convert_annotation(year, image_id)

5.随机旋转图像

程序来源:https://www.ctolib.com/topics/44419.html

import os
import cv2 as cv
import numpy as np
import math
import random

# Source folder scanned for .bmp images, and destination folder for the rotated copies.
root_dir, output_dir = (
    "C:/Users/admin/Desktop/data",
    "C:/Users/admin/Desktop/output/rotate",
)

#rotate
def ImgRotate(image, rotate_angle=5, scale=1.0):
    """Rotate *image* by *rotate_angle* degrees without cutting off its corners.

    The output canvas is enlarged to the bounding box of the rotated (and
    scaled) image, and the result is centred in it.

    Args:
        image: image array as returned by ``cv.imread`` (``shape[0]`` is the
            height, ``shape[1]`` the width).
        rotate_angle: rotation in degrees. Positive values rotate
            counter-clockwise (OpenCV convention).
        scale: isotropic scale factor applied together with the rotation.

    Returns:
        The rotated image. Areas not covered by the source are filled with
        ``cv.warpAffine``'s default constant border.
    """
    h, w = image.shape[:2]
    rangle = np.deg2rad(rotate_angle)  # degrees -> radians for cos/sin
    # Size of the axis-aligned box that contains the rotated, scaled image.
    nw = (abs(w * np.cos(rangle)) + abs(h * np.sin(rangle))) * scale
    nh = (abs(w * np.sin(rangle)) + abs(h * np.cos(rangle))) * scale
    # Rotate (and scale) about the centre of the new canvas...
    rot_mat = cv.getRotationMatrix2D((nw * 0.5, nh * 0.5), rotate_angle, scale)  # 2x3 matrix
    # ...after shifting the source so its centre lands on that canvas centre.
    # The shift is folded into the matrix's translation column.
    rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]
    rot_img = cv.warpAffine(image, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))),
                            flags=cv.INTER_LANCZOS4)
    return rot_img

os.makedirs(output_dir, exist_ok=True)  # cv.imwrite does not create missing folders
for parent, dirnames, filenames in os.walk(root_dir):
    for filename in filenames:
        # Compare the real extension, case-insensitively. split('.')[1] fails
        # on names without a dot and checks the wrong part when a name has
        # several dots.
        if os.path.splitext(filename)[1].lower() != '.bmp':
            continue
        # Use the folder os.walk is currently visiting, so images in
        # sub-folders are found too.
        name = os.path.join(parent, filename)
        image = cv.imread(name)
        if image is None:  # imread returns None instead of raising on unreadable files
            print('skip unreadable image: ' + name)
            continue
        angle = random.randint(0, 180)
        rot_img = ImgRotate(image, angle)
        cv.imwrite(os.path.join(output_dir, filename), rot_img)

6.裁剪图像(从图像中央裁剪,裁剪位置固定而非随机,默认object位于图像中央;每张图的裁剪比例是随机的)

# -*- coding: utf-8 -*-
"""
Created on Sun Sep  6 09:36:20 2020

@author: admin
"""
import os
import os.path
import cv2 as cv
import random

# Source folder scanned for .bmp images, and destination folder for the cropped copies.
root_dir, output_dir = (
    "C:/Users/admin/Desktop/data",
    "C:/Users/admin/Desktop/output/cut",
)

def ImgCut(image, cut_scale=3.0):
    """Crop the central region of *image*.

    ``int(width / cut_scale)`` columns are removed from both the left and
    right edges, and ``int(height / cut_scale)`` rows from both the top and
    bottom. ``cut_scale`` should be greater than 2; at or below 2 the result
    can be empty.

    Args:
        image: image array indexed as ``image[row, column]``, e.g. from
            ``cv.imread``. Extra trailing axes (channels) are kept.
        cut_scale: divisor controlling how much of each edge is cut away.

    Returns:
        A slice of *image* covering the central region.
    """
    h, w = image.shape[:2]
    x_cut = int(w / cut_scale)
    y_cut = int(h / cut_scale)
    # Arrays are indexed [rows (y), columns (x)], so the height slice comes first.
    dst = image[y_cut:h - y_cut, x_cut:w - x_cut]
    return dst

os.makedirs(output_dir, exist_ok=True)  # cv.imwrite does not create missing folders
for parent, dirnames, filenames in os.walk(root_dir):
    for filename in filenames:
        # Compare the real extension, case-insensitively. split('.')[1] fails
        # on names without a dot and checks the wrong part when a name has
        # several dots.
        if os.path.splitext(filename)[1].lower() != '.bmp':
            continue
        # Use the folder os.walk is currently visiting, so images in
        # sub-folders are found too. File names should not contain Chinese
        # characters: cv.imread may fail on non-ASCII paths.
        name = os.path.join(parent, filename)
        print(name + '\n')
        img = cv.imread(name)
        if img is None:  # imread returns None instead of raising on unreadable files
            print('skip unreadable image: ' + name)
            continue
        scale = random.uniform(2.5, 5)  # random crop strength per image
        result = ImgCut(img, scale)
        cv.imwrite(os.path.join(output_dir, filename), result)

7.替换class索引

有时需要增加或删除类别,这时class的索引会发生变化,需要批量修改标注txt文件中的类别编号。

# Replace class index 1 with 0 in every YOLO label (.txt) file under rootdir.
import os

rootdir = "C:/Users/admin/Desktop/darknet/data/obj"
OLD_CLASS = '1'  # class index to look for
NEW_CLASS = '0'  # class index written in its place


def replace_class_index(line, old=OLD_CLASS, new=NEW_CLASS):
    """Return *line* with its class index changed from *old* to *new*.

    A YOLO label line is ``<class> <x> <y> <w> <h>``. Only the first
    space-separated token is compared, so e.g. class ``10`` is not mistaken
    for class ``1``. Any other line is returned unchanged.
    """
    head, sep, tail = line.partition(' ')
    if head != old:
        return line
    return new + sep + tail


for parent, dirnames, filenames in os.walk(rootdir):
    for filename in filenames:
        if os.path.splitext(filename)[1].lower() != '.txt':
            continue
        label_path = os.path.join(parent, filename)
        with open(label_path, 'r') as file_object:
            lines = file_object.readlines()
        # Convert every line (not just the first) and write all of them back,
        # so label files with several objects lose no boxes.
        new_lines = [replace_class_index(line) for line in lines]
        if new_lines != lines:
            print(label_path)
            with open(label_path, 'w') as file_object:
                file_object.writelines(new_lines)
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页