# create_train_data_yolov5: convert VOC-style XML annotations into a YOLOv5 train/val dataset

from os import listdir, system
from os.path import isfile, isdir, join
import random, yaml
import sys, cv2, os
from shutil import copyfile
from distutils.dir_util import copy_tree
from xml.dom import minidom
# --- Run configuration -------------------------------------------------------
# When only_val is True, only the `where` source is processed and all of its
# images go to the validation split (train_amount is forced to 0 below).
only_val = False
where = 'sh1'

if only_val:
    data_from = [where]
else:
    # Alternative: data_from = ['sh1', 'ms_showroom', 'ms_1']
    data_from = ['original']

# Fraction of each source folder's images assigned to the training split.
train_rate = 0.9

# Class names come from the dataset YAML; a name's index in this list is the
# class id written to the YOLO .txt label files. The file is closed right after
# parsing, and read as UTF-8 so non-ASCII class names decode correctly
# regardless of the platform's locale encoding.
with open("../data/voc1.yaml", encoding="utf-8") as y:
    cfg = yaml.load(y, Loader=yaml.FullLoader)
labels = cfg["names"]


def create_dir_not_exist(path):
    """Create directory ``path``, including missing parents, if nothing exists there.

    An already-existing path (directory or otherwise) is left untouched.
    ``exist_ok=True`` covers the case where another process creates the
    directory between the existence check and the ``makedirs`` call.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def get_position(box, img_width, img_height):
    """Convert a VOC ``<bndbox>`` element into a normalized YOLO box.

    Each corner coordinate is parsed as a float and truncated to a whole
    pixel. The center is the exact midpoint of those corners. Computing it
    as ``xmin + int(width / 2)`` would floor it by up to half a pixel for
    odd-sized boxes.

    Args:
        box: DOM element with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` children.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        Dict with ``center_y``, ``center_x``, ``width`` and ``height``, each
        divided by the matching image dimension.
    """
    def coord(tag):
        # Annotations may store coordinates like "12.0"; truncate to int pixels.
        return int(float(box.getElementsByTagName(tag)[0].firstChild.data))

    xmin, ymin = coord('xmin'), coord('ymin')
    xmax, ymax = coord('xmax'), coord('ymax')
    width = xmax - xmin
    height = ymax - ymin
    center_x = (xmin + xmax) / 2
    center_y = (ymin + ymax) / 2

    return {'center_y': center_y / img_height, 'center_x': center_x / img_width,
            'width': width / img_width, 'height': height / img_height}

def change_xml2txt(anno_path, xml_name, save_path):
    """Convert one VOC-style XML annotation into a YOLO ``.txt`` label file.

    Reads ``anno_path/xml_name`` and writes ``save_path/<stem>.txt``. The
    output has one line per object, ``<class_id> <center_x> <center_y>
    <width> <height>``, with coordinates normalized by the image size read
    from the XML's ``<size>`` element. Objects whose ``<name>`` is not in
    the module-level ``labels`` list are skipped.

    Args:
        anno_path: Directory containing the XML annotation.
        xml_name: File name of the annotation, e.g. ``"img_001.xml"``.
        save_path: Directory the ``.txt`` label file is written to.

    Returns:
        True once the label file has been written.
    """
    xml = minidom.parse(os.path.join(anno_path, xml_name))
    size = xml.getElementsByTagName('size')[0]
    img_width = int(size.getElementsByTagName('width')[0].firstChild.data)
    img_height = int(size.getElementsByTagName('height')[0].firstChild.data)

    # Replace only the final extension. str.replace('.xml', '.txt') leaves
    # "a.XML" unchanged and rewrites every ".xml" occurrence in the name.
    txt_name = os.path.splitext(xml_name)[0] + '.txt'
    with open(os.path.join(save_path, txt_name), 'w') as txt:
        for obj in xml.getElementsByTagName('object'):
            name = obj.getElementsByTagName('name')[0].firstChild.data
            try:
                index = labels.index(name)
            except ValueError:
                # Class not listed in the dataset config: skip this object.
                continue
            box = obj.getElementsByTagName('bndbox')[0]
            position = get_position(box, img_width, img_height)
            txt.write("{} {} {} {} {}\n".format(
                index, position['center_x'], position['center_y'],
                position['width'], position['height']))
    return True

def _xml_name_for(file_name):
    """Return the annotation file name expected for image ``file_name``."""
    # Swap only the final extension. split(".")[1] breaks on dot-less or
    # multi-dot names, and str.replace("png", "xml") also rewrites "png"/"jpg"
    # appearing elsewhere in the name.
    return os.path.splitext(file_name)[0] + ".xml"


def _export_split(file_names, img_path, anno_path, xmls, split):
    """Write YOLO labels and copy images of one split ("train" or "val").

    Images without a matching annotation in ``xmls`` are skipped.
    """
    labels_dir = "../VOC/labels/{}".format(split)
    images_dir = "../VOC/images/{}".format(split)
    # open()/copyfile() fail if the destination folders do not exist yet.
    os.makedirs(labels_dir, exist_ok=True)
    os.makedirs(images_dir, exist_ok=True)
    for file_name in file_names:
        print(file_name)
        xml = _xml_name_for(file_name)
        if xml not in xmls:
            continue
        if change_xml2txt(anno_path, xml, labels_dir):
            copyfile("{}/{}".format(img_path, file_name),
                     "{}/{}".format(images_dir, file_name))
        else:
            print("------------false----------------")


def move_data_to_ready_data():
    """Randomly split every source in ``data_from`` into train/val YOLO data.

    For each source, images are read from ``../data/<source>/images`` and
    annotations from ``../data/<source>/annotations``. The image list is
    shuffled. The first ``train_rate`` fraction goes to ``../VOC/*/train``
    and the rest to ``../VOC/*/val``. When ``only_val`` is set, everything
    goes to val.
    """
    for data in data_from:
        img_path = "../data/{}/images".format(data)
        anno_path = "../data/{}/annotations".format(data)
        img_names = listdir(img_path)
        # A set makes each per-image annotation lookup O(1).
        xmls = set(listdir(anno_path))
        train_amount = int(len(img_names) * train_rate)
        random.shuffle(img_names)
        if only_val:
            train_amount = 0
        print("==================train data=======================")
        _export_split(img_names[:train_amount], img_path, anno_path, xmls, "train")
        print("==================val data=======================")
        _export_split(img_names[train_amount:], img_path, anno_path, xmls, "val")
                
    
if __name__ == "__main__":
    import glob

    # NOTE: the raw data for each source is expected under ../data/<source>/
    # (images/ and annotations/) before running. Never hard-code credentials
    # for fetching it into this script.
    if not only_val:
        # Clear previous exports so stale labels/images do not leak into the
        # new split. This is a portable equivalent of `rm -f <dir>/*`: it needs
        # no POSIX shell, removes only files and symlinks, and, like the shell
        # glob, ignores dotfiles.
        for split_dir in ("../VOC/labels/train", "../VOC/images/train",
                          "../VOC/labels/val", "../VOC/images/val"):
            for stale in glob.glob(os.path.join(split_dir, "*")):
                if os.path.isfile(stale) or os.path.islink(stale):
                    os.remove(stale)
    move_data_to_ready_data()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值