Street View Text数据集图像来源自于Google Street View,数据集中的图像包含好质量和低质量的图像,通常低质量图片居多。下载的数据集中包含两个文件,train.xml和test.xml文件。格式如下:
<?xml version="1.0" encoding="utf-8"?>
<tagset>
<image>
<imageName>img/14_03.jpg</imageName>
<address>341 Southwest 10th Avenue Portland OR</address>
<lex>LIVING,ROOM,THEATERS,KENNY,ZUKE,DELICATESSEN,CLYDE,COMMON,ACE,HOTEL,PORTLAND,ROSE,CITY,BOOKS,STUMPTOWN,COFFEE,ROASTERS,RED,CAP,GARAGE,FISH,GROTTO,SEAFOOD,RESTAURANT,AURA,RESTAURANT,LOUNGE,ROCCO,PIZZA,PASTA,BUFFALO,EXCHANGE,MARK,SPENCER,LIGHT,FEZ,BALLROOM,READING,FRENZY,ROXY,SCANDALS,MARTINOTTI,CAFE,DELI,CROWSENBERG,HALF</lex>
<Resolution x="1280" y="880"/>
<taggedRectangles>
<taggedRectangle height="75" width="236" x="375" y="253">
<tag>LIVING</tag>
</taggedRectangle>
<taggedRectangle height="76" width="175" x="639" y="272">
<tag>ROOM</tag>
</taggedRectangle>
<taggedRectangle height="87" width="281" x="839" y="283">
<tag>THEATERS</tag>
</taggedRectangle>
</taggedRectangles>
</image>
这种格式显然不能直接运用,需要对数据集进行处理。程序里包含两个函数read_SVT_dataset和crop,前者将文字所在的区域信息提取到txt文件中,后者则从原图像中剪裁出单词区域并保存为图像。txt文件保存形式如下:
dataset/img/14_03.jpg # 原图像的路径 3 # 原图像中识别单词的数目 LIVING,375,253,236,75, # 识别单词,bounding box[x, y, width, height] ROOM,639,272,175,76, THEATERS,839,283,281,87,
程序全部代码如下:
import os
import sys
import tensorflow as tf
from PIL import Image
from xml.etree import ElementTree
# Path of the SVT training-set annotation file.
train_XML_src = 'dataset/train.xml'
# Path of the SVT test-set annotation file.
test_XML_src = 'dataset/test.xml'
# Folder that holds the original SVT images.
image_src = 'dataset/img/'
def read_SVT_dataset(XML_src, write_txt, image_src):
    '''Extract image and word-box information from an SVT xml file into a txt file.

    For each <image> element, writes:
        <image_src><imageName without 'img/' prefix>
        <number of tagged rectangles>
        <word>,<x>,<y>,<width>,<height>,      (one line per rectangle)

    @param XML_src: path of the xml annotation file (train.xml / test.xml)
    @param write_txt: path of the txt file to write the extracted info to
    @param image_src: folder that contains the image files
    '''
    # Context managers guarantee both files are closed even if parsing fails
    # (the original code leaked the output handle on any exception).
    with open(write_txt, 'w') as out_fd, open(XML_src) as xml_fd:
        tree = ElementTree.parse(xml_fd)
        for image_node in tree.iter('image'):
            for child in image_node:
                if child.tag == 'imageName':
                    # text looks like 'img/14_03.jpg'; drop the 'img/' prefix
                    # so the configured image folder can be prepended instead.
                    out_fd.write(image_src + child.text[4:] + '\n')
                elif child.tag == 'taggedRectangles':
                    rects = list(child)
                    # number of word rectangles in this image
                    out_fd.write(str(len(rects)) + '\n')
                    for rect in rects:
                        attrs = rect.attrib
                        word = rect.find('tag').text
                        # keep the original trailing-comma record format
                        out_fd.write('%s,%s,%s,%s,%s,\n' % (
                            word, attrs['x'], attrs['y'],
                            attrs['width'], attrs['height']))
def crop(text):
    '''Read the records produced by read_SVT_dataset and crop each word
    region out of its source image, saving it as <word>.jpg.

    @param text: path of the txt file (image path / count / word records)
    @return: None
    '''
    out_dir = './dataset/images/'
    # Create the output folder up front instead of crashing on the first save.
    os.makedirs(out_dir, exist_ok=True)
    # 'with' closes the listing file on every exit path (the original leaked
    # the handle and also shadowed the builtins `file` and `list`).
    with open(text, 'r') as fd:
        while True:
            header = fd.readline()
            if not header:
                break  # end of file
            img_path = header.rstrip('\n')
            with Image.open(img_path) as image:
                count = int(fd.readline())
                for _ in range(count):
                    fields = fd.readline().split(',')
                    word = fields[0]
                    x, y, width, height = (int(v) for v in fields[1:5])
                    region = image.crop([x, y, x + width, y + height])
                    # NOTE(review): identical words from different images save
                    # to the same filename and overwrite each other — confirm
                    # whether that is acceptable for this dataset.
                    region.save(out_dir + word + '.jpg')
def main(argv):
    '''Entry point: extract annotations from both SVT xml files into txt
    listings, then crop the word regions those listings describe.

    @param argv: command-line arguments forwarded by tf.app.run (unused)
    '''
    jobs = ((train_XML_src, 'train.txt'), (test_XML_src, 'test.txt'))
    # First pass: build both txt listings from the xml annotations.
    for xml_path, txt_path in jobs:
        read_SVT_dataset(xml_path, txt_path, image_src)
    # Second pass: cut the word images out of the originals.
    for _, txt_path in jobs:
        crop(txt_path)
if __name__ == '__main__':
    # tf.app.run() parses TensorFlow command-line flags and then calls main(argv).
    # This is the TensorFlow 1.x API; under TF 2 it lives at tf.compat.v1.app.run.
    tf.app.run()
最后得到剪裁的图片如下图所示:
总共有350张原图像,得到571张剪裁的图片