关于tf-faster-rcnn处理的几个小代码

1.将自己的数据集的文件名统一修改为pascal voc的命名格式:
代码名称:xml_node.py

# coding:utf-8
import os
import os.path
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
from w3lib.html import remove_tags

findpath ='/media/ymw/LANKEXIN/02facexml'
filenames=os.listdir(findpath)
s=[]
save_path ='/media/ymw/LANKEXIN/02facexml'
for file in filenames:

    tree = ET.parse(os.path.join(findpath, file))
    root = tree.getroot()

    # 为图片改名字01face (1).jpg改成000001.jpg
    nodename = root[1].text
    re_nodename = nodename[:-4].replace('02face (', '').replace(')', '')
    new_num = int(re_nodename) + 1696
    new_nodename = str(new_num)
    new_nodename1 = new_nodename.zfill(6) + '.jpg'
    # new_nodename1 = re_nodename.zfill(6) + '.jpg'
    root[1].text = new_nodename1


    tree.write(os.path.join(save_path, file))

2.将自己的数据集标注格式转化成pascal voc 的标注格式:
代码名字:new_xml.py

# coding:utf-8
import os
import os.path
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
findpath = '/media/ymw/LANKEXIN/02facexml'
filenames = os.listdir(findpath)
s = []
xml_path = '/media/ymw/LANKEXIN/02facexml'
for file in filenames:

# 依次添加节点(按照pascal voc 的标注格式)
    tree = ET.parse(os.path.join(findpath, file))
    root = tree.getroot()
    source = Element('source')
    root.append(source)


    database = Element('database')
    database.text = 'The VOC2007 Database'
    source.append(database)
    annotation = Element('annotation')
    annotation.text = 'PASCAL VOC2007'
    source.append(annotation)
    image = Element('image')
    image.text = 'flickr'
    source.append(image)
    # 获取bndbox的节点信息
    object = root.getchildren()[3]
    face = object.getchildren()[0]
    big = face.getchildren()[1]

    # 将...作为object的子节点添加进去

    one = Element('name')
    one.text = 'cat'
    object.append(one)
    two = Element('pose')
    two.text = 'Unspecified'
    object.append(two)
    three = Element('truncated')
    three.text = '0'
    object.append(three)
    four = Element('difficult')
    four.text = '0'
    object.append(four)
    object.append(big)
    object.remove(face)
    # bndbox = ET.SubElement(face, 'bndbox')
    # bndbox.text = big.text
    nodename2 = root[3].tag
    root[3].tag = 'object'
    root[3][4].tag = 'bndbox'

# 将修改之后的标注写入到一个新的文件夹
    tree.write(os.path.join(xml_path, file))


3.从Annotations文件夹产生ImageSets,xml文件生成txt文件:
代码名字:xml2txt.py

# coding:utf-8
import os
import random

trainval_percent = 0.98  # 作为trainval_percent的比例,在整个数据集,可以修改
train_percent = 0.98  # 用于训练的数据的比例,可以修改
xmlfilepath = '/home/ymw/桌面/xinan/Cat Face/catxml_total'
txtsavepath = '/home/ymw/桌面/xinan/Cat Face/cattxt_total/'
total_xml = os.listdir(xmlfilepath)

num=len(total_xml)
list=range(num)
tv=int(num*trainval_percent)
tr=int(tv*train_percent)
trainval= random.sample(list,tv)
train=random.sample(trainval,tr)

ftrainval = open(txtsavepath+'/trainval.txt', 'w')
ftest = open(txtsavepath+'/test.txt', 'w')
ftrain = open(txtsavepath+'/train.txt', 'w')
fval = open(txtsavepath+'/val.txt', 'w')

for i in list:
    org_name=total_xml[i][:-4]+'\n'
    re_name=org_name.replace('01face (', '').replace(')', '')
    name=re_name.zfill(7)
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest .close()

4.可视化数据集检查错误:
代码名字:view.py

# -*- coding: UTF-8 -*-
from xml.etree import ElementTree as ET
import cv2
import os

imgpath = '/home/ymw/PycharmProjects/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/JPEGImages'
xmlpath = '/home/ymw/桌面/facebig_xml'
imgnames = os.listdir(imgpath)
for imgname in imgnames:

    img = cv2.imread(os.path.join(imgpath, imgname))
    index = imgname[:-4]
    xmlname = index + '.xml'
   
    tree = ET.parse(os.path.join(xmlpath, xmlname))
    root = tree.getroot()
    bndbox = root.getchildren()[3].getchildren()[4]
  
    big_xmin = int(bndbox.getchildren()[0].text)
    big_ymin = int(bndbox.getchildren()[1].text)
    big_xmax = int(bndbox.getchildren()[2].text)
    big_ymax = int(bndbox.getchildren()[3].text)
       
    cv2.rectangle(img, (big_xmin, big_ymin), (big_xmax, big_ymax), (0, 0, 255), 2)

    cv2.imwrite(os.path.join('/home/ymw/PycharmProjects/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/viewimgs',imgname), img)

5.检查自己的标注是否有坐标值为0(在tf-faster-rcnn中有影响,导致程序报错:total loss = Nan):

#coding:utf-8
from xml.etree import ElementTree as ET

import os
import shutil,re

xmlpath = '/media/ymw/LANKEXIN/facebig_xml_right'
xmlnames = os.listdir(xmlpath)
i = 0
a = []
# f = open('/home/ymw/桌面/1.txt', 'w')
for xmlname in xmlnames:
    # f.writelines(xmlname + '\n')
    tree = ET.parse(os.path.join(xmlpath, xmlname))
    root = tree.getroot()
    small = root.getchildren()[3].getchildren()[0].getchildren()[0]
    big = root.getchildren()[3].getchildren()[4]
    # object = tree.findall('object')

    big_xmin = int(big.getchildren()[0].text)
    big_ymin = int(big.getchildren()[1].text)
    big_xmax = int(big.getchildren()[2].text)
    big_ymax = int(big.getchildren()[3].text)
    small_xmin = int(small.getchildren()[0].text)
    small_ymin = int(small.getchildren()[1].text)
    small_xmax = int(small.getchildren()[2].text)
    small_ymax = int(small.getchildren()[3].text)

    if big_xmin > small_xmin:
        i = i+1
       # newpath = '/media/ymw/LANKEXIN/02facexml/02face_cuowu'
       # os.remove(os.path.join(xmlpath, xmlname))

    #     xmlnum = xmlname[:-4].replace('02face (', '').replace(')', '')
    #
    #
        print(i)
        print(xmlname)

该段程序会筛选出自己的数据集中坐标值为0的图片名称,i表示个数,xmlname就是所要筛选出来的文件名字。
6.测试集图片可视化处理:
实现将测试机图片copy到data/demo文件夹中替换原本的测试图片,并将所有的测试集图片名称输出到一个列表中,以此完成测试集图片的可视化.
代码名称:copytestimg.py

import os
import shutil

# img_txtpath是voc数据集中ImageSets中的text.txt,以此找到测试集图片的索引值
img_txtpath ='tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt'
# dir是.jpg文件所在地址
dir = 'tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/JPEGImages'
# newpath是将测试集图片copy到的目的地
newpath = '/home/ymw/桌面/demo_test'

with open(img_txtpath) as f:
    img_index = [x.strip()+'.jpg' for x in f.readlines()]  # 这个是参照pascal_voc.py中读取图片的索引值的处理方法
                                                           # 将所有的索引值+'.jpg'都加入到一个列表中,刚好可以放到demo.py中
    # print(img_index)
    for file in img_index:
        img_path = os.path.join(dir, file)
        shutil.copy(img_path, newpath)  # 实现测试集图片的copy

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值