最近开始鼓捣TensorFlow object detection api(以下简称TFODA)相关的东西,主要面向遥感数据。训练无外乎是数据集、预训练模型和训练平台。由于是刚接触python、tensorflow相关内容,为了降低难度,当然是越少diy越好,尽量先参考现有成功方案再进一步自定义训练过程和内容。能用现有方案就用现有方案,先把平台搭建,预训练模型选取,数据集准备,训练,导出训练模型并进行object detection这一整套流程走通再说。再以后为了满足自己需求的修改就是后文了。毕竟,走通才是关键,才有信心进行以后的修改或优化。
整套解决方案实施过程中,主要参考了《21个项目玩转深度学习:基于TensorFlow的实践详解》一书中的第5章尤其是第5.2节的内容。(在此,感谢该书作者何之源老师)。方案中主要涉及预训练模型、数据集和训练平台(或框架)。预训练模型和平台,准备就先用书中提到的。该书中的训练数据集不是面向遥感方面的,所以需要寻找遥感相关的数据。
找了许多遥感相关数据,但是都有各自的数据格式,包括图片格式,txt内容格式,xml内容格式等等。为了适应TFODA平台(或者说框架更为妥当),需要将数据集转换为平台接受的格式,也即Pacal VOC或者直接是tfrecord格式。
由于遥感数据集很多,结合数据集大小(有的数据集太大,下载太慢,怪我网络太不给力)和现有从数据集转换为TFODA平台能接受的格式的现有方案成熟度(不是每种数据集都有成熟方案很方便的转换为VOC格式或者TFRecord格式)等各方面原因,最后选中了:
(1条消息)航空遥感图像(Aerial Images)目标检测数据集汇总 - 蚂蚁搬家 - CSDN博客
https://blog.csdn.net/hongxingabc/article/details/78833485
文章中提到的第3个例子:
3,NWPU VHR-10:西北工业大学标注的航天遥感目标检测数据集,共有800张图像,其中包含目标的650张,背景图像150张,目标包括:飞机、舰船、油罐、棒球场、网球场、篮球场、田径场、港口、桥梁、车辆10个类别。开放下载,大概73M.
相关链接:http://jiong.tea.ac.cn/people/JunweiHan/NWPUVHR10dataset.html
数据下载地址,http://pan.baidu.com/s/1hqwzXeG。
NWPU VHR-10数据集很小,70多M。接下来要做的就是将NWPU VHR-10向标准Pascal voc格式(2007或和2012都一样)再转换为TFRecord格式或者直接向TFrecord转换。
开始为了省事儿尝试了:
所用的方法,即直接转换为TFRecord格式,能够成功转换为tfrecord格式,但是TFODA平台再使用该tfrecord文件的时候总是报错,而且错误一直没能成功解决,遂放弃了该方案。转而采用转换成voc格式再转为tfrecord格式。
Pascal VOC 官方数据集voc2012下主要包含Annotations、ImageSets、JPEGImages、SegmentationClass、SegmentationObject五个文件夹。在Object detection任务(除了目标检测任务还有目标识别,目标分割等任务)中主要用到了前三个,而且这三个中的ImageSets中也主要用了Main文件夹下的内容。JPEGImages则直接将NWPU VHR-10数据集中的positive image set文件夹下的图片整体拷贝过来即可。接着我会描述Annotations和ImageSets文件夹中的内容是如何生成的。
Annotations文件夹中的内容的生成主要参考:
把数据集NWPU VHR-10转成pascal voc的格式 - summer2day的博客 - CSDN博客
https://blog.csdn.net/summer2day/article/details/83064727
修改了部分代码(主要涉及文件读取或者生成路径),全部代码如下:
from lxml.etree import Element, SubElement, tostring
from xml.dom.minidom import parseString
import xml.dom.minidom
import os
import sys
from PIL import Image
#https://blog.csdn.net/summer2day/article/details/83064727#comments
# 把txt中的内容写进xml
def deal(path):
files = os.listdir(path) # 列出所有文件
for file in files:
filename = os.path.splitext(file)[0] # 分割出文件名
# print(filename)
sufix = os.path.splitext(file)[1] # 分割出后缀
if sufix == '.txt':
xmins = []
ymins = []
xmaxs = []
ymaxs = []
names = []
num, xmins, ymins, xmaxs, ymaxs, names = readtxt(file)
# dealpath = path + "/" + filename + ".xml"
dealpath = xmlPath + "/" + filename + ".xml"
filename = filename + '.jpg'
with open(dealpath, 'w') as f:
writexml(dealpath, filename, num, xmins, ymins, xmaxs, ymaxs, names)
# 读取图片的高和宽写入xml
def dealwh(path):
files = os.listdir(path) # 列出所有文件
for file in files:
filename = os.path.splitext(file)[0] # 分割出文件名
sufix = os.path.splitext(file)[1] # 分割出后缀
if sufix == '.jpg':
height, width = readsize(file)
# dealpath = path + "/" + filename + ".xml"
dealpath = xmlPath + "/" + filename + ".xml"
gxml(dealpath, height, width)
# 读取txt文件
def readtxt(p):
p_file = txtPath + "/" + p
with open(p_file, 'r') as f:
contents = f.read()
# print(contents)
objects = contents.split('\n') # 分割出每个物体
for i in range(objects.count('')): # 去掉空格项
objects.remove('')
# print(objects)
num = len(objects) # 物体的数量
# print(num)
xmins = []
ymins = []
xmaxs = []
ymaxs = []
names = []
for objecto in objects:
# print(objecto)
xmin = objecto.split(',')[0]
xmin = xmin.split('(')[1]
xmin = xmin.strip()
ymin = objecto.split(',')[1]
ymin = ymin.split(')')[0]
ymin = ymin.strip()
xmax = objecto.split(',')[2]
xmax = xmax.split('(')[1]
xmax = xmax.strip()
ymax = objecto.split(',')[3]
ymax = ymax.split(')')[0]
ymax = ymax.strip()
name = objecto.split(',')[4]
name = name.strip()
if name == "1 " or name == "1":
name = 'airplane'
elif name == "2 " or name == "2":
name = 'ship'
elif name == "3 " or name == "3":
name = 'storage tank'
elif name == "4 " or name == "4":
name = 'baseball diamond'
elif name == "5 " or name == "5":
name = 'tennis court'
elif name == "6 " or name == "6":
name = 'basketball court'
elif name == "7 " or name == "7":
name = 'ground track field'
elif name == "8 " or name == "8":
name = 'habor'
elif name == "9 " or name == "9":
name = 'bridge'
elif name == "10 " or name == "10":
name = 'vehicle'
else:
print(txtPath)
# print(xmin,ymin,xmax,ymax,name)
xmins.append(xmin)
ymins.append(ymin)
xmaxs.append(xmax)
ymaxs.append(ymax)
names.append(name)
# print(num,xmins,ymins,xmaxs,ymaxs,names)
return num, xmins, ymins, xmaxs, ymaxs, names
# 在xml文件中添加宽和高
def gxml(path, height, width):
dom = xml.dom.minidom.parse(path)
root = dom.documentElement
heights = root.getElementsByTagName('height')[0]
heights.firstChild.data = height
# print(height)
widths = root.getElementsByTagName('width')[0]
widths.firstChild.data = width
# print(width)
with open(path, 'w') as f:
# with open(xmlPath, 'w') as f:
dom.writexml(f)
return
# 创建xml文件
def writexml(path, filename, num, xmins, ymins, xmaxs, ymaxs, names, height='256', width='256'):
node_root = Element('annotation')
node_folder = SubElement(node_root, 'folder')
node_folder.text = "VOC2007"
node_filename = SubElement(node_root, 'filename')
node_filename.text = "%s" % filename
node_size = SubElement(node_root, "size")
node_width = SubElement(node_size, 'width')
node_width.text = '%s' % width
node_height = SubElement(node_size, 'height')
node_height.text = '%s' % height
node_depth = SubElement(node_size, 'depth')
node_depth.text = '3'
for i in range(num):
node_object = SubElement(node_root, 'object')
node_name = SubElement(node_object, 'name')
node_name.text = '%s' % names[i]
node_name = SubElement(node_object, 'pose')
node_name.text = '%s' % "unspecified"
node_name = SubElement(node_object, 'truncated')
node_name.text = '%s' % "0"
node_difficult = SubElement(node_object, 'difficult')
node_difficult.text = '0'
node_bndbox = SubElement(node_object, 'bndbox')
node_xmin = SubElement(node_bndbox, 'xmin')
node_xmin.text = '%s' % xmins[i]
node_ymin = SubElement(node_bndbox, 'ymin')
node_ymin.text = '%s' % ymins[i]
node_xmax = SubElement(node_bndbox, 'xmax')
node_xmax.text = '%s' % xmaxs[i]
node_ymax = SubElement(node_bndbox, 'ymax')
node_ymax.text = '%s' % ymaxs[i]
xml = tostring(node_root, pretty_print=True)
dom = parseString(xml)
with open(path, 'wb') as f:
f.write(xml)
return
def readsize(p):
p_file=imagePath+"/"+p
img=Image.open(p_file)
width = img.size[0]
height = img.size[1]
return height, width
if __name__ == "__main__":
# path = ("D:/NWPU VHR-10 dataset/NWPU VHR-10 dataset/test")
imagePath = ("./NWPU VHR-10 dataset/positive image set")
txtPath = ("./NWPU VHR-10 dataset/ground truth")
xmlPath = ("./NWPU VHR-10 dataset/annotations")
deal(txtPath)
dealwh(imagePath)
该代码主要生成pascal voc官方数据集中的 annotations目录下对应的内容,也即xml文件。
ImageSets/Main文件夹下的部分txt文件则参考以下内容生成:
目标检测之VOC2007格式数据集制作 - duanyajun987的博客 - CSDN博客
https://blog.csdn.net/duanyajun987/article/details/81507656
代码稍作修改,内容如下:
import os
import random
#https://blog.csdn.net/duanyajun987/article/details/81507656
trainval_percent = 0.5
train_percent = 0.5
xmlfilepath = 'Annotations'
txtsavepath = 'ImageSets/Main'
total_xml = os.listdir(xmlfilepath)
num=len(total_xml)
list=range(num)
tv=int(num*trainval_percent)
tr=int(tv*train_percent)
trainval= random.sample(list,tv)
train=random.sample(trainval,tr)
ftrainval = open(txtsavepath+'/trainval.txt', 'w')
ftest = open(txtsavepath+'/test.txt', 'w')
ftrain = open(txtsavepath+'/train.txt', 'w')
fval = open(txtsavepath+'/val.txt', 'w')
for i in list:
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
这里生成了train.txt、val.txt、trainval.txt及test.txt文件。ImageSets/Main文件夹下的***_train.txt、***_trainval.txt、***_val.txt一系列文件则暂未找到生成方法。在通过平台下的create_pascal_tf_record.py(\models\research\object_detection\dataset_tools文件夹下)将pascal voc格式的数据转换为tfrecord的时候如果没有aeroplane_train.txt、aeroplane_val.txt,会提示找不到这些文件的错误。多方查找资料,发现有些人说只用到该文件的第一列,第二列的正负1没用到:
python - Create PASCAL Voc for Tensorflow Object Detection API - Stack Overflow
https://stackoverflow.com/questions/44891732/create-pascal-voc-for-tensorflow-object-detection-api
观察所有***_train.txt、***_trainval.txt、***_val.txt,相应的文件的第一列都是一样的。并且create_pascal_tf_record.py也只读取了aeroplane这一类相应***_train.txt、***_val.txt文件,即aeroplane_train.txt、aeroplane_trainval.txt、aeroplane_val.txt。而且第一列内容跟train.txt、val.txt、trainval.txt也是一样的。所以,我尝试手动将train.txt、val.txt、trainval.txt复制了一份,改名为aeroplane_train.txt、aeroplane_trainval.txt、aeroplane_val.txt。然后create_pascal_tf_record.py顺利运行了,也即完成了从pascal voc 到 tfrecord的转换,并且转换后的tfrecord能够正常用于训练。这里特别提醒一句,如果想把自己的数据转换为标准pascal voc格式,而且分类跟pascal voc不一致,则应该把aeroplane改为自己相应的类别名,并且要把create_pascal_tf_record.py中相应的aeroplane字段也改为自己的类别名。
另外,在调用create_pascal_tf_record.py转换数据的时候,用到了 models/research/object_detection/data/pascal_label_map.pbtxt 映射文件,这一点在何之源老师的书中没有强调。但是在数据转化的时候,这个文件很重要。使用自己的数据,而且分类跟voc官方分类不一样的时候,就需要修改该文件(建议复制一份,放到其他地方,然后手动(分类少的话,能很快修改)修改相应内容,并且相应修改create_pascal_tf_record.py中关于该文件的路径)。我在讲NWPU VHR-10转成pascal voc时手动修改生成的pascal_label_map.pbtxt文件内容为:
item {
id: 1
name: 'airplane'
}
item {
id: 2
name: 'ship'
}
item {
id: 3
name: 'storage tank'
}
item {
id: 4
name: 'baseball diamond'
}
item {
id: 5
name: 'tennis court'
}
item {
id: 6
name: 'basketball court'
}
item {
id: 7
name: 'ground track field'
}
item {
id: 8
name: 'habor'
}
item {
id: 9
name: 'bridge'
}
item {
id: 10
name: 'vehicle'
}
至此,将数据集NWPU VHR-10转成pascal voc的格式基本完活。能力有限,留有小的遗憾,那就是一直没弄明白的是***_train.txt、***_trainval.txt、***_val.txt中的第二列在训练中究竟用没用到,用到的话,有什么影响。能加快训练收敛速度或者提高准确率、回收率啥的吗?有待大神帮忙解决。在此先谢过了。
感谢文中引用的各位大神的文章,多谢各位的分享。