tensorflow2-yolov4-程序小白学习笔记
项目场景:tensorflow2-yolov4-程序学习笔记
提示:程序由bubbiiiing提供:链接: link.
本程序没有做修改,只是添加了自己的详细注释,作为自己的一个记录。
程序:voc2yolo4.py
本程序为:./VOCdevikit/voc2007/voc2yolo4.py
程序用途:将已经生成的xml文件(放置在:"./VOCdevikit/voc2007/Annotations")。分别归类到"./VOCdevikit/voc2007/ImageSets/Main"中的训练集、训练验证集、验证集、检测集中。
xmlfilepath = "./Annotations"
saveBasePath = "./ImageSets/Main/"
#----------------------------------------------------------------------#
# 想要增加测试集修改trainval_percent
# train_percent不需要修改
#----------------------------------------------------------------------#
trainval_percent = 1 # 设置全部训练集的数量的x%放置在trainval中,(1-x)放置在test中。
train_percent = 1 # 首先trainval与train内容一样,train中保留x,其余的放置到val中。
# 单独提取出xxx.xml类型文件名
temp_xml = os.listdir(xmlfilepath) # 返回路径下所有文件名称包括文件类型名
total_xml = []
for xml in temp_xml: # 整合文件下的xml文件名,放入total_xml中
if xml.endswith(".xml"):
total_xml.append(xml)
# print(total_xml)
num = len(total_xml) # 返回xml文件数量
list = range(num) # list为文件数量变range格式
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
trainval = random.sample(list, tv) # 截取list中随机长度为tv的元素,将range中的数打乱截取tv个出来,代表xml文件。
train = random.sample(trainval, tr) # 同上,在随机种子已经设置好。在trainval中提取tr个出来,乱序。
print("train and val size", tv)
print("traub suze", tr)
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w') # os.path.join拼接路径(自动加/)
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w') # open(xxx.'w')打开只用于写入文件
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
for i in list:
name = total_xml[i][:-4]+'\n' # 提取标签文件名,去掉文件类型文字。
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
程序:voc_annotation.py
本程序为:./annotation.py
程序用途:整合图片文件路径和与之对应的xml文件信息,按照上面的程序划分的训练集、验证集、测试集。
输出结构格式为:图片路径+(bndbox的四点信息+物体在classes定义中的排序号)*n(图片中存在物体的数量)。
保存到根目录2007—test,2007-train,2007-val文件中。为后续信息导入做准备。
import xml.etree.ElementTree as ET
from os import getcwd
sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')] # 设置输出选项
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
# 分类结果设置
def convert_annotation(year, image_id, list_file):
in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id), encoding='utf-8')
tree=ET.parse(in_file) # 将xml文档解析为ElementTree对象
root = tree.getroot() # 获取element类的树根
for obj in root.iter('object'): # 遍历root下面的所有object项目
difficult = 0
if obj.find('difficult')!=None:
difficult = obj.find('difficult').text
cls = obj.find('name').text # 把这个object的‘name’中内容放入到cls中。(提出xml文件中物体的类型名)
if cls not in classes or int(difficult)==1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
# 输出每一个object中bndbox的四个参数。
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
# 生成图片文件的路径+图片中所有bnbbox的参数+name参数在classes中的序号
wd = getcwd() # 得到当前路径
for year, image_set in sets:
image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'% (year, image_set)).read().strip().split()
# 打开文件,read()读取所有文件内容,strip为删除空白字符包括'\n', '\r', '\t', ' ',split为按照字符串分割
# 把xml文件名全部放入 image_ids列表
print(image_ids)
list_file = open('%s_%s.txt'%(year, image_set), 'w') # 生成2007—test,2007-train,2007-val文件
for image_id in image_ids:
list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg'%(wd, year, image_id)) # 生成xml文件对应图片路径
convert_annotation(year, image_id, list_file)
list_file.write('\n')
list_file.close()