TensorFlow杂记 - 生成和读取TFRecord(一)

利用闲暇时间,参考SSD-TensorFlow的实现和网络资源,总结了两种生成和读取TFRecord的方法,代码均已测试通过。

软件平台:PyCharm 2019.2 + TensorFlow 1.13.1 + CUDA 10.0 + cuDNN 7.6.5

GPU: 1080Ti

 

不啰嗦了,直接上代码,有问题可发送邮件 <tecsai@163.com>, 或直接留言讨论。

 

TFRecord生成:

import os
import tensorflow as tf 
from PIL import Image  # 注意Image,后面会用到
import xml.etree.ElementTree as ET

 
# Root of the VOC-style dataset. Every other path below is derived from it,
# so relocating the dataset only requires changing this one line.
# (String concatenation with '\\' keeps the values identical to the original
# Windows paths regardless of the OS running the script.)
filepath = 'H:\\11_DataSet\\QD'
JpgFilePath = filepath + '\\JPEGImages\\'   # input images
XmlFilePath = filepath + '\\Annotations\\'  # matching Pascal-VOC XML annotations
# Output TFRecord file that receives one serialized Example per image.
writer = tf.python_io.TFRecordWriter(filepath + '\\qd.tfrecord')

VOC_LABELS = {
    'B01': (0, 'B01'),
    'D01': (1, 'D01'),
    'G01': (2, 'G01'),
    'W01': (3, 'W01'),
    'W02': (4, 'W02'),
    'W03': (5, 'W03'),
    'W04': (6, 'W04'),
    'T01': (7, 'T01'),
    'R01': (8, 'R01'),
    'F01': (9, 'F01'),
    'S01': (10, 'S01'),
    'G02': (11, 'G02'),
    'G03': (12, 'G03'),
    'W05': (13, 'W05'),
    'I18': (14, 'I18'),
    'I01': (15, 'I01'),
    'I02': (16, 'I02'),
    'I03': (17, 'I03'),
    'I04': (18, 'I04'),
    'I99': (19, 'I99'),
}

def _optional_int(parent, tag, default=0):
    """Return the integer text of child ``<tag>`` of *parent*, or *default*.

    *default* is used when the child element is absent or has no text.
    Presence is tested with ``is None`` on purpose: an ElementTree Element
    with no sub-elements is falsy, so ``if parent.find(tag):`` would be False
    even for ``<difficult>1</difficult>`` and the value would be lost.
    """
    node = parent.find(tag)
    if node is None or node.text is None:
        return default
    return int(node.text)


# Convert every image in JpgFilePath plus its same-named XML annotation in
# XmlFilePath into one tf.train.Example, and append it to `writer`.
try:
    for img_name in os.listdir(JpgFilePath):
        img_path = JpgFilePath + img_name  # full path of this image
        print(img_path)

        # --- Image -------------------------------------------------------
        # Stored as raw, uncompressed pixel bytes from PIL (NOT JPEG bytes).
        # `with` closes the file handle; tobytes() forces the lazy load
        # before the handle is released.
        with Image.open(img_path) as img:
            image_data = img.tobytes()

        # --- Annotation --------------------------------------------------
        # splitext handles any extension length (.jpg, .jpeg, .png),
        # unlike slicing off a fixed 4 characters.
        xml_name = os.path.splitext(img_name)[0]
        xml_path = XmlFilePath + xml_name + '.xml'
        print(xml_path)
        root = ET.parse(xml_path).getroot()

        # Image shape as declared in the XML: [height, width, depth].
        size = root.find('size')
        shape = [int(size.find('height').text),
                 int(size.find('width').text),
                 int(size.find('depth').text)]

        labels = []
        labels_text = []
        difficult = []
        truncated = []
        ymin = []
        xmin = []
        ymax = []
        xmax = []
        for obj in root.findall('object'):
            label = obj.find('name').text
            labels.append(int(VOC_LABELS[label][0]))
            labels_text.append(label.encode('ascii'))
            difficult.append(_optional_int(obj, 'difficult'))
            truncated.append(_optional_int(obj, 'truncated'))

            # Box coordinates divided by the XML-declared height/width.
            bbox = obj.find('bndbox')
            ymin.append(float(bbox.find('ymin').text) / shape[0])
            xmin.append(float(bbox.find('xmin').text) / shape[1])
            ymax.append(float(bbox.find('ymax').text) / shape[0])
            xmax.append(float(bbox.find('xmax').text) / shape[1])
        print(labels)

        # 'RAW' (not 'JPEG') because image/encoded holds raw pixel bytes;
        # the paired reader decodes it with tf.decode_raw.
        image_format = b'RAW'
        example = tf.train.Example(features=tf.train.Features(feature={
                'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
                'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
                'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
                'image/shape': tf.train.Feature(int64_list=tf.train.Int64List(value=shape)),
                'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
                'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)),
                'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)),
                'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)),
                'image/object/bbox/label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
                'image/object/bbox/label_text': tf.train.Feature(bytes_list=tf.train.BytesList(value=labels_text)),
                'image/object/bbox/difficult': tf.train.Feature(int64_list=tf.train.Int64List(value=difficult)),
                'image/object/bbox/truncated': tf.train.Feature(int64_list=tf.train.Int64List(value=truncated)),
                'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),
                'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))}))

        writer.write(example.SerializeToString())  # serialize the Example to bytes
finally:
    # Close even if an image/XML fails to load, so records already written
    # are flushed to disk.
    writer.close()

TFRecord读取:

import os 
import tensorflow as tf 
from PIL import Image  

storepath = 'H:\\11_DataSet\\QD\\'  # where decoded images are saved

filename = 'H:\\11_DataSet\\QD\\qd.tfrecord'

# TF 1.x queue-based input pipeline; with no num_epochs it cycles over the
# file indefinitely, so the fixed-count loop below never runs dry.
filename_queue = tf.train.string_input_producer([filename])

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # (record key, serialized Example)
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/encoded': tf.FixedLenFeature((), tf.string),
    })

# The writer stores raw pixel bytes, so decode_raw + reshape rebuilds the
# array. Reshape to each record's stored [height, width, depth] instead of a
# hard-coded [1080, 1920, 3], so images of any resolution can be read.
image = tf.decode_raw(features['image/encoded'], tf.uint8)
image = tf.reshape(image, features['image/shape'])
shape = features['image/shape']
height = features['image/height']
width = features['image/width']

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(10):
            print("loop %d\n" % i)
            example, s, h, w = sess.run([image, shape, height, width])
            print(s)
            print(h)
            print(w)

            # 'RGB' mode assumes the stored depth is 3.
            img = Image.fromarray(example, 'RGB')
            img.save(storepath + str(i) + '.jpg')  # save the decoded image
    finally:
        # Always stop and join the queue-runner threads; if a step raised
        # before this point the process could otherwise hang on exit.
        coord.request_stop()
        coord.join(threads)


 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值