将数据集NWPU VHR-10转成pascal voc的格式

最近开始鼓捣TensorFlow object detection api(以下简称TFODA)相关的东西,主要面向遥感数据。训练无外乎是数据集、预训练模型和训练平台。由于是刚接触python、tensorflow相关内容,为了降低难度,当然是越少diy越好,尽量先参考现有成功方案再进一步自定义训练过程和内容。能用现有方案就用现有方案,先把平台搭建,预训练模型选取,数据集准备,训练,导出训练模型并进行object detection这一整套流程走通再说。再以后为了满足自己需求的修改就是后文了。毕竟,走通才是关键,才有信心进行以后的修改或优化。

整套解决方案实施过程中,主要参考了《21个项目玩转深度学习:基于TensorFlow的实践详解》一书中的第5章尤其是第5.2节的内容。(在此,感谢该书作者何之源老师)。方案中主要涉及预训练模型、数据集和训练平台(或框架)。预训练模型和平台,准备就先用书中提到的。该书中的训练数据集不是面向遥感方面的,所以需要寻找遥感相关的数据。

找了许多遥感相关数据,但是都有各自的数据格式,包括图片格式,txt内容格式,xml内容格式等等。为了适应TFODA平台(或者说框架更为妥当),需要将数据集转换为平台接受的格式,也即Pacal VOC或者直接是tfrecord格式。       

由于遥感数据集很多,结合数据集大小(有的数据集太大,下载太慢,怪我网络太不给力)和现有从数据集转换为TFODA平台能接受的格式的现有方案成熟度(不是每种数据集都有成熟方案很方便的转换为VOC格式或者TFRecord格式)等各方面原因,最后选中了:

(1条消息)航空遥感图像(Aerial Images)目标检测数据集汇总 - 蚂蚁搬家 - CSDN博客
https://blog.csdn.net/hongxingabc/article/details/78833485
文章中提到的第3个例子:
3,NWPU VHR-10:西北工业大学标注的航天遥感目标检测数据集,共有800张图像,其中包含目标的650张,背景图像150张,目标包括:飞机、舰船、油罐、棒球场、网球场、篮球场、田径场、港口、桥梁、车辆10个类别。开放下载,大概73M.

相关链接:http://jiong.tea.ac.cn/people/JunweiHan/NWPUVHR10dataset.html

数据下载地址,http://pan.baidu.com/s/1hqwzXeG。

 

       NWPU VHR-10数据集很小,70多M。接下来要做的就是将NWPU VHR-10向标准Pascal voc格式(2007或和2012都一样)再转换为TFRecord格式或者直接向TFrecord转换。

开始为了省事儿尝试了:

TFrecords类型数据集制作与读取(NWPU VHR-10数据集为例) - Puremelo - CSDN博客
https://blog.csdn.net/qq_39858278/article/details/85112547

所用的方法,即直接转换为TFRecord格式,能够成功转换为tfrecord格式,但是TFODA平台再使用该tfrecord文件的时候总是报错,而且错误一直没能成功解决,遂放弃了该方案。转而采用转换成voc格式再转为tfrecord格式。

Pascal VOC 官方数据集voc2012下主要包含Annotations、ImageSets、JPEGImages、SegmentationClass、SegmentationObject五个文件夹。在Object detection任务(除了目标检测任务还有目标识别,目标分割等任务)中主要用到了前三个,而且这三个中的ImageSets中也主要用了Main文件夹下的内容。JPEGImages则直接将NWPU VHR-10数据集中的positive image set文件夹下的图片整体拷贝过来即可。接着我会描述Annotations和ImageSets文件夹中的内容是如何生成的。

Annotations文件夹中的内容的生成主要参考:

把数据集NWPU VHR-10转成pascal voc的格式 - summer2day的博客 - CSDN博客
https://blog.csdn.net/summer2day/article/details/83064727

修改了部分代码(主要涉及文件读取或者生成路径),全部代码如下:

from lxml.etree import Element, SubElement, tostring
from xml.dom.minidom import parseString
import xml.dom.minidom
import os
import sys
from PIL import Image

#https://blog.csdn.net/summer2day/article/details/83064727#comments

# 把txt中的内容写进xml
def deal(path):
    files = os.listdir(path)  # 列出所有文件
    for file in files:
        filename = os.path.splitext(file)[0]  # 分割出文件名
        # print(filename)
        sufix = os.path.splitext(file)[1]  # 分割出后缀
        if sufix == '.txt':
            xmins = []
            ymins = []
            xmaxs = []
            ymaxs = []
            names = []
            num, xmins, ymins, xmaxs, ymaxs, names = readtxt(file)
            # dealpath = path + "/" + filename + ".xml"
            dealpath = xmlPath + "/" + filename + ".xml"
            filename = filename + '.jpg'
            with open(dealpath, 'w') as f:
                writexml(dealpath, filename, num, xmins, ymins, xmaxs, ymaxs, names)


# 读取图片的高和宽写入xml
def dealwh(path):
    files = os.listdir(path)  # 列出所有文件
    for file in files:
        filename = os.path.splitext(file)[0]  # 分割出文件名
        sufix = os.path.splitext(file)[1]  # 分割出后缀
        if sufix == '.jpg':
            height, width = readsize(file)
            # dealpath = path + "/" + filename + ".xml"
            dealpath = xmlPath + "/" + filename + ".xml"
            gxml(dealpath, height, width)


# 读取txt文件
def readtxt(p):
    p_file = txtPath + "/" + p
    with open(p_file, 'r') as f:
        contents = f.read()
        # print(contents)
        objects = contents.split('\n')  # 分割出每个物体
        for i in range(objects.count('')):  # 去掉空格项
            objects.remove('')
        # print(objects)
        num = len(objects)  # 物体的数量
        # print(num)
        xmins = []
        ymins = []
        xmaxs = []
        ymaxs = []
        names = []
        for objecto in objects:
            # print(objecto)
            xmin = objecto.split(',')[0]
            xmin = xmin.split('(')[1]
            xmin = xmin.strip()

            ymin = objecto.split(',')[1]
            ymin = ymin.split(')')[0]
            ymin = ymin.strip()

            xmax = objecto.split(',')[2]
            xmax = xmax.split('(')[1]
            xmax = xmax.strip()

            ymax = objecto.split(',')[3]
            ymax = ymax.split(')')[0]
            ymax = ymax.strip()

            name = objecto.split(',')[4]
            name = name.strip()

            if name == "1 " or name == "1":
                name = 'airplane'
            elif name == "2 " or name == "2":
                name = 'ship'
            elif name == "3 " or name == "3":
                name = 'storage tank'
            elif name == "4 " or name == "4":
                name = 'baseball diamond'
            elif name == "5 " or name == "5":
                name = 'tennis court'
            elif name == "6 " or name == "6":
                name = 'basketball court'
            elif name == "7 " or name == "7":
                name = 'ground track field'
            elif name == "8 " or name == "8":
                name = 'habor'
            elif name == "9 " or name == "9":
                name = 'bridge'
            elif name == "10 " or name == "10":
                name = 'vehicle'
            else:
                print(txtPath)
            # print(xmin,ymin,xmax,ymax,name)
            xmins.append(xmin)
            ymins.append(ymin)
            xmaxs.append(xmax)
            ymaxs.append(ymax)
            names.append(name)
        # print(num,xmins,ymins,xmaxs,ymaxs,names)
        return num, xmins, ymins, xmaxs, ymaxs, names


# 在xml文件中添加宽和高
def gxml(path, height, width):
    dom = xml.dom.minidom.parse(path)
    root = dom.documentElement
    heights = root.getElementsByTagName('height')[0]
    heights.firstChild.data = height
    # print(height)

    widths = root.getElementsByTagName('width')[0]
    widths.firstChild.data = width
    # print(width)
    with open(path, 'w') as f:
    # with open(xmlPath, 'w') as f:
        dom.writexml(f)
    return


# 创建xml文件
def writexml(path, filename, num, xmins, ymins, xmaxs, ymaxs, names, height='256', width='256'):
    node_root = Element('annotation')

    node_folder = SubElement(node_root, 'folder')
    node_folder.text = "VOC2007"

    node_filename = SubElement(node_root, 'filename')
    node_filename.text = "%s" % filename

    node_size = SubElement(node_root, "size")
    node_width = SubElement(node_size, 'width')
    node_width.text = '%s' % width

    node_height = SubElement(node_size, 'height')
    node_height.text = '%s' % height

    node_depth = SubElement(node_size, 'depth')
    node_depth.text = '3'
    for i in range(num):
        node_object = SubElement(node_root, 'object')
        node_name = SubElement(node_object, 'name')
        node_name.text = '%s' % names[i]
        node_name = SubElement(node_object, 'pose')
        node_name.text = '%s' % "unspecified"
        node_name = SubElement(node_object, 'truncated')
        node_name.text = '%s' % "0"
        node_difficult = SubElement(node_object, 'difficult')
        node_difficult.text = '0'
        node_bndbox = SubElement(node_object, 'bndbox')
        node_xmin = SubElement(node_bndbox, 'xmin')
        node_xmin.text = '%s' % xmins[i]
        node_ymin = SubElement(node_bndbox, 'ymin')
        node_ymin.text = '%s' % ymins[i]
        node_xmax = SubElement(node_bndbox, 'xmax')
        node_xmax.text = '%s' % xmaxs[i]
        node_ymax = SubElement(node_bndbox, 'ymax')
        node_ymax.text = '%s' % ymaxs[i]

    xml = tostring(node_root, pretty_print=True)
    dom = parseString(xml)
    with open(path, 'wb') as f:
        f.write(xml)
    return


def readsize(p):
    p_file=imagePath+"/"+p
    img=Image.open(p_file)
    width = img.size[0]
    height = img.size[1]
    return height, width

if __name__ == "__main__":
    # path = ("D:/NWPU VHR-10 dataset/NWPU VHR-10 dataset/test")
    imagePath = ("./NWPU VHR-10 dataset/positive image set")
    txtPath = ("./NWPU VHR-10 dataset/ground truth")
    xmlPath = ("./NWPU VHR-10 dataset/annotations")
    deal(txtPath)
    dealwh(imagePath)

该代码主要生成pascal voc官方数据集中的 annotations目录下对应的内容,也即xml文件。

 

ImageSets/Main文件夹下的部分txt文件则参考以下内容生成:

目标检测之VOC2007格式数据集制作 - duanyajun987的博客 - CSDN博客
https://blog.csdn.net/duanyajun987/article/details/81507656

代码稍作修改,内容如下:

import os  
import random  
 
#https://blog.csdn.net/duanyajun987/article/details/81507656

trainval_percent = 0.5  
train_percent = 0.5  
xmlfilepath = 'Annotations'  
txtsavepath = 'ImageSets/Main'  
total_xml = os.listdir(xmlfilepath)  
 
num=len(total_xml)  
list=range(num)  
tv=int(num*trainval_percent)  
tr=int(tv*train_percent)  
trainval= random.sample(list,tv)  
train=random.sample(trainval,tr)  
 
ftrainval = open(txtsavepath+'/trainval.txt', 'w')  
ftest = open(txtsavepath+'/test.txt', 'w')  
ftrain = open(txtsavepath+'/train.txt', 'w')  
fval = open(txtsavepath+'/val.txt', 'w')  
 
for i  in list:  
    name=total_xml[i][:-4]+'\n'  
    if i in trainval:  
        ftrainval.write(name)  
        if i in train:  
            ftrain.write(name)  
        else:  
            fval.write(name)  
    else:  
        ftest.write(name)  
 
ftrainval.close()  
ftrain.close()  
fval.close()  
ftest .close()

这里生成了train.txt、val.txt、trainval.txt及test.txt文件。ImageSets/Main文件夹下的***_train.txt、***_trainval.txt、***_val.txt一系列文件则暂未找到生成方法。在通过平台下的create_pascal_tf_record.py(\models\research\object_detection\dataset_tools文件夹下)将pascal voc格式的数据转换为tfrecord的时候如果没有aeroplane_train.txt、aeroplane_val.txt,会提示找不到这些文件的错误。多方查找资料,发现有些人说只用到该文件的第一列,第二列的正负1没用到:

python - Create PASCAL Voc for Tensorflow Object Detection API - Stack Overflow
https://stackoverflow.com/questions/44891732/create-pascal-voc-for-tensorflow-object-detection-api

观察所有***_train.txt、***_trainval.txt、***_val.txt,相应的文件的第一列都是一样的。并且create_pascal_tf_record.py也只读取了aeroplane这一类相应***_train.txt、***_val.txt文件,即aeroplane_train.txt、aeroplane_trainval.txt、aeroplane_val.txt。而且第一列内容跟train.txt、val.txt、trainval.txt也是一样的。所以,我尝试手动将train.txt、val.txt、trainval.txt复制了一份,改名为aeroplane_train.txt、aeroplane_trainval.txt、aeroplane_val.txt。然后create_pascal_tf_record.py顺利运行了,也即完成了从pascal voc 到 tfrecord的转换,并且转换后的tfrecord能够正常用于训练。这里特别提醒一句,如果想把自己的数据转换为标准pascal voc格式,而且分类跟pascal voc不一致,则应该把aeroplane改为自己相应的类别名,并且要把create_pascal_tf_record.py中相应的aeroplane字段也改为自己的类别名。

另外,在调用create_pascal_tf_record.py转换数据的时候,用到了 models/research/object_detection/data/pascal_label_map.pbtxt 映射文件,这一点在何之源老师的书中没有强调。但是在数据转化的时候,这个文件很重要。使用自己的数据,而且分类跟voc官方分类不一样的时候,就需要修改该文件(建议复制一份,放到其他地方,然后手动(分类少的话,能很快修改)修改相应内容,并且相应修改create_pascal_tf_record.py中关于该文件的路径)。我在讲NWPU VHR-10转成pascal voc时手动修改生成的pascal_label_map.pbtxt文件内容为:

item {
  id: 1
  name: 'airplane'
}

item {
  id: 2
  name: 'ship'
}

item {
  id: 3
  name: 'storage tank'
}

item {
  id: 4
  name: 'baseball diamond'
}

item {
  id: 5
  name: 'tennis court'
}

item {
  id: 6
  name: 'basketball court'
}

item {
  id: 7
  name: 'ground track field'
}

item {
  id: 8
  name: 'habor'
}

item {
  id: 9
  name: 'bridge'
}

item {
  id: 10
  name: 'vehicle'
}

至此,将数据集NWPU VHR-10转成pascal voc的格式基本完活。能力有限,留有小的遗憾,那就是一直没弄明白的是***_train.txt、***_trainval.txt、***_val.txt中的第二列在训练中究竟用没用到,用到的话,有什么影响。能加快训练收敛速度或者提高准确率、回收率啥的吗?有待大神帮忙解决。在此先谢过了。

感谢文中引用的各位大神的文章,多谢各位的分享。

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值