from os import listdir, system
from os.path import isfile, isdir, join
import random, yaml
import sys, cv2, os
from shutil import copyfile
from distutils.dir_util import copy_tree
from xml.dom import minidom
# When True, only the dataset folder named by `where` is processed and every
# image goes to the val split (see move_data_to_ready_data).
only_val = False
where = 'sh1'
if only_val:
    data_from = [where]
else:
    # Other dataset folders previously used: 'sh1', 'ms_showroom', 'ms_1'.
    data_from = ['original']
# Fraction of each dataset's (shuffled) images assigned to the train split.
train_rate = 0.9

# Class names; a label's YOLO class index is its position in this list.
# `with` guarantees the config file is closed once parsed.
with open("../data/voc1.yaml", encoding="utf-8") as y:
    cfg = yaml.load(y, Loader=yaml.FullLoader)
labels = cfg["names"]
# Newer YOLOv5 data files write `names` as a {index: name} mapping instead of
# a list; normalize it so `labels.index(name)` keeps working. A gap in the
# indices raises KeyError rather than silently shifting class ids.
if isinstance(labels, dict):
    labels = [labels[i] for i in range(len(labels))]
def create_dir_not_exist(path):
    """Create directory *path* unless something already exists at that path.

    Only the final path component is created; parent directories must
    already exist. Attempting the ``mkdir`` and tolerating
    ``FileExistsError`` (EAFP) avoids the race where the path appears
    between an existence check and the ``mkdir`` call.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Already present (directory or file): nothing to do.
        pass
def get_position(box, img_width, img_height):
    """Convert a VOC ``<bndbox>`` element into normalized YOLO box values.

    Args:
        box: DOM element with ``xmin``, ``ymin``, ``xmax`` and ``ymax``
            child elements whose text is a (possibly fractional) pixel
            coordinate.
        img_width: Image width in pixels; x values are divided by it.
        img_height: Image height in pixels; y values are divided by it.

    Returns:
        dict with keys ``center_x``, ``center_y``, ``width`` and ``height``,
        each normalized by the corresponding image dimension.
    """
    def coord(tag):
        # Text content of the first child element named `tag`, as a float.
        return float(box.getElementsByTagName(tag)[0].firstChild.data)

    xmin, ymin = coord('xmin'), coord('ymin')
    xmax, ymax = coord('xmax'), coord('ymax')
    # Keep sub-pixel precision: truncating the coordinates and the half-size
    # to int biased box centers toward the top-left.
    center_x = (xmin + xmax) / 2
    center_y = (ymin + ymax) / 2
    return {
        'center_y': center_y / img_height,
        'center_x': center_x / img_width,
        'width': (xmax - xmin) / img_width,
        'height': (ymax - ymin) / img_height,
    }
def change_xml2txt(anno_path, xml_name, save_path):
    """Convert one Pascal VOC XML annotation into a YOLO-format label file.

    Reads ``anno_path/xml_name`` and writes ``save_path/<stem>.txt`` with one
    line per object: ``<class_index> <center_x> <center_y> <width> <height>``.
    The class index is the object's position in the module-level ``labels``
    list. Box values are normalized by the image size taken from the XML's
    ``<size>`` element. Objects whose class name is not in ``labels`` are
    skipped.

    Returns:
        Always True; callers use the result as a keep/skip flag.
    """
    xml = minidom.parse(os.path.join(anno_path, xml_name))
    size = xml.getElementsByTagName('size')[0]
    img_width = int(size.getElementsByTagName('width')[0].firstChild.data)
    img_height = int(size.getElementsByTagName('height')[0].firstChild.data)
    # splitext only touches the real extension; str.replace('.xml', ...)
    # would also rewrite '.xml' occurring inside the stem.
    txt_name = os.path.splitext(xml_name)[0] + '.txt'
    with open(os.path.join(save_path, txt_name), 'w') as txt:
        for obj in xml.getElementsByTagName('object'):
            box = obj.getElementsByTagName('bndbox')[0]
            position = get_position(box, img_width, img_height)
            # Pretty-printed XML may pad the name with whitespace/newlines.
            name = obj.getElementsByTagName('name')[0].firstChild.data.strip()
            try:
                index = labels.index(name)
            except ValueError:
                # Class not listed in the dataset config: skip this object.
                continue
            txt.write("{} {} {} {} {}\n".format(
                index, position['center_x'], position['center_y'],
                position['width'], position['height']))
    return True
def _annotation_name(file_name):
    """Return the VOC annotation file name paired with image *file_name*."""
    # splitext handles any extension (.jpg/.jpeg/.png/...), names with several
    # dots, and names without a dot, and never alters the stem.
    return os.path.splitext(file_name)[0] + ".xml"


def _export_split(file_names, img_path, anno_path, xmls, split):
    """Write YOLO labels and copy images for every annotated file into ../VOC/<split>."""
    label_dir = "../VOC/labels/{}".format(split)
    image_dir = "../VOC/images/{}".format(split)
    os.makedirs(label_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)
    for file_name in file_names:
        print(file_name)
        xml = _annotation_name(file_name)
        if xml not in xmls:
            # Image without an annotation file: nothing to export.
            continue
        if change_xml2txt(anno_path, xml, label_dir):
            copyfile(os.path.join(img_path, file_name),
                     os.path.join(image_dir, file_name))
        else:
            print("------------false----------------")


def move_data_to_ready_data():
    """Split each dataset in ``data_from`` into train/val and export it under ../VOC.

    For each dataset folder ``../data/<name>``, the file names in ``images/``
    are shuffled. The first ``train_rate`` fraction goes to the train split and
    the rest to val; when ``only_val`` is True, everything goes to val. Only
    images with a matching ``annotations/<stem>.xml`` are exported.
    """
    for data in data_from:
        img_path = "../data/{}/images".format(data)
        anno_path = "../data/{}/annotations".format(data)
        img_names = listdir(img_path)
        # A set gives O(1) membership tests instead of scanning a list per image.
        xmls = set(listdir(anno_path))
        random.shuffle(img_names)
        train_amount = 0 if only_val else int(len(img_names) * train_rate)
        print("==================train data=======================")
        _export_split(img_names[:train_amount], img_path, anno_path, xmls, "train")
        print("==================val data=======================")
        _export_split(img_names[train_amount:], img_path, anno_path, xmls, "val")
if __name__ == "__main__":
    import glob

    # Clear previous export results so a rerun does not mix old and new files.
    # In only_val mode the existing export is kept and new files are added.
    # Done in Python rather than via `rm -f` so it works without a POSIX shell.
    if not only_val:
        for split_dir in ("../VOC/labels/train", "../VOC/images/train",
                          "../VOC/labels/val", "../VOC/images/val"):
            for path in glob.glob(os.path.join(split_dir, "*")):
                # `rm -f dir/*` did not delete subdirectories either.
                if not os.path.isdir(path):
                    os.remove(path)
    move_data_to_ready_data()
# create_train_data_yolov5
# (web-page residue: "最新推荐文章于 2022-08-25 13:47:02 发布" — "latest recommended article published 2022-08-25 13:47:02")