一、生成 TFRecord 文件(以图片为例)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
#Author: 1477517404@qq.com
import tensorflow as tf
from PIL import Image
import os
import io
def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def process_image_channels(image):
    """Make sure ``image`` is a 3-channel RGB PIL image.

    The incoming mode and the resulting flag are printed for debugging.

    Args:
        image: A PIL image in any mode.

    Returns:
        A ``(image, process_flag)`` tuple. ``image`` is the RGB image.
        ``process_flag`` is True when a conversion was performed and False
        when the input was already RGB (it is then returned unchanged).
    """
    print(image.mode)
    if image.mode == 'RGBA':
        # Drop the alpha band by rebuilding the image from R, G and B only.
        red, green, blue, _alpha = image.split()
        converted, process_flag = Image.merge("RGB", (red, green, blue)), True
    elif image.mode == 'RGB':
        converted, process_flag = image, False
    else:
        converted, process_flag = image.convert('RGB'), True
    print('process_flag is : ', process_flag)
    return converted, process_flag
def create_tf_example(image_path, label, resize=None):
    """Build a ``tf.train.Example`` for one image file.

    The file is decoded with PIL and forced to RGB via
    ``process_image_channels``. It is optionally resized, then always
    re-encoded as JPEG, so the stored bytes match the declared
    ``image/format`` whatever the source format was.

    Args:
        image_path: Path to the image, read through ``tf.io.gfile``.
        label: Integer class label stored under ``image/class/label``.
        resize: Optional ``(width, height)`` size passed to PIL's
            ``Image.resize``. ``None`` keeps the original size.

    Returns:
        A ``tf.train.Example`` with the encoded image, format, label,
        height and width features.
    """
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        raw_bytes = fid.read()
    image = Image.open(io.BytesIO(raw_bytes))
    image, _ = process_image_channels(image)
    if resize is not None:
        image = image.resize(tuple(resize))
    # Always re-encode. The source may be PNG, BMP, etc., but the reader
    # decodes 'image/encoded' with decode_jpeg.
    bytes_io = io.BytesIO()
    image.save(bytes_io, format='JPEG')
    encode_jpg = bytes_io.getvalue()
    print(len(encode_jpg))
    # PIL reports size as (width, height).
    width, height = image.size
    tf_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': bytes_feature(encode_jpg),
                'image/format': bytes_feature(b'jpg'),
                # Store the caller-supplied label instead of a constant.
                'image/class/label': int64_feature(label),
                'image/height': int64_feature(height),
                'image/width': int64_feature(width),
            }
        )
    )
    return tf_example
def generate_tfrecord(annotation_dict, record_path, resize=None):
    """Write one serialized ``tf.train.Example`` per annotated image.

    Args:
        annotation_dict: Mapping of image path to integer label.
        record_path: Destination TFRecord file. An existing file is replaced.
        resize: Optional size forwarded unchanged to ``create_tf_example``.

    Returns:
        The number of examples written. Paths that do not exist are
        reported on stdout and skipped.
    """
    num_tf_example = 0
    # The context manager closes the writer even if an image fails to
    # decode part-way through the loop.
    with tf.io.TFRecordWriter(record_path) as writer:
        for image_path, label in annotation_dict.items():
            # Constructing a GFile does not test existence, so ask gfile
            # directly and skip the entry instead of trying to decode it.
            if not tf.io.gfile.exists(image_path):
                print("{} does not exist".format(image_path))
                continue
            tf_example = create_tf_example(image_path, label, resize)
            writer.write(tf_example.SerializeToString())
            num_tf_example += 1
            if num_tf_example % 2 == 0:
                print("Create %d TF_example" % num_tf_example)
    return num_tf_example
def get_annotation_dict(image_dir,
                        extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif',
                                    '.tif', '.tiff', '.webp')):
    """Map every image file directly inside ``image_dir`` to label ``1``.

    Only regular files whose extension (compared case-insensitively) is in
    ``extensions`` are included. This keeps out the TFRecord output that
    ``main`` writes into the same directory (``data/image.record``), along
    with any other non-image files and subdirectories; otherwise PIL would
    fail to decode them.

    Args:
        image_dir: Directory to scan. The scan is not recursive, and a
            trailing separator is optional.
        extensions: Accepted lower-case file suffixes, including the dot.

    Returns:
        A dict mapping ``os.path.join(image_dir, name)`` to ``1``, built in
        sorted file-name order so repeated runs produce the same record order.
    """
    annotation_dict = {}
    for name in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, name)
        if not os.path.isfile(path):
            continue
        if os.path.splitext(name)[1].lower() not in extensions:
            continue
        annotation_dict[path] = 1  # every image gets label 1, for simplicity
    return annotation_dict
def main():
    """Convert every image found in data/ into the TFRecord data/image.record."""
    annotations = get_annotation_dict('data/')
    generate_tfrecord(annotations, 'data/image.record')
if __name__ == "__main__":
    # Build the TFRecord only when run as a script, not on import.
    main()
二、使用 tf.data 读取生成的 TFRecord 文件
#!/usr/bin/env python
# -*- coding:utf-8 -*-
#Author:1477517404@qq.com
import tensorflow as tf
import multiprocessing as mt
from PIL import Image
import matplotlib.pyplot as plt
def parser(record):
    """Decode one serialized ``tf.train.Example`` into ``(image, label)``.

    Args:
        record: A scalar string tensor holding one serialized example.

    Returns:
        image: The decoded JPEG pixels, reshaped to (height, width, 3) using
            the height/width stored in the record.
        label: The int64 scalar stored under ``image/class/label``.
    """
    features = {
        'image/encoded': tf.FixedLenFeature((), default_value='', dtype=tf.string),
        'image/format': tf.FixedLenFeature((), default_value='jpg', dtype=tf.string),
        'image/class/label': tf.FixedLenFeature([], default_value=0, dtype=tf.int64),
        'image/height': tf.FixedLenFeature([], default_value=0, dtype=tf.int64),
        'image/width': tf.FixedLenFeature([], default_value=0, dtype=tf.int64),
    }
    example = tf.parse_single_example(record, features)
    height = example['image/height']
    width = example['image/width']
    # channels=3 pins the decoded depth so it always matches the reshape
    # below. The shape must be (height, width); using width twice breaks
    # every non-square image.
    image = tf.image.decode_jpeg(example['image/encoded'], channels=3)
    image = tf.reshape(image, (height, width, 3))
    label = example['image/class/label']
    return image, label
def read_tfRecord(record_paths=('data/image.record',), buffer_size=100,
                  batch_size=5, num_epochs=5):
    """Build a shuffled, batched, repeated dataset from TFRecord file(s).

    Calling this with no arguments reproduces the original pipeline:
    data/image.record, a shuffle buffer of 100, batches of 5 and 5 epochs.

    Args:
        record_paths: One path string, or a sequence of paths.
        buffer_size: Shuffle buffer size. It must be at least the number of
            samples for a full uniform shuffle.
        batch_size: Number of examples per batch.
        num_epochs: How many times the batched data is repeated.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image_batch, label_batch)``.
    """
    if isinstance(record_paths, str):
        record_paths = [record_paths]
    dataset = tf.data.TFRecordDataset(list(record_paths))
    dataset = dataset.map(parser)  # decode each serialized example
    # Shuffle individual samples before batching so batches are mixed too.
    dataset = dataset.shuffle(buffer_size=buffer_size).batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    return dataset
def main():
    """Pull up to 40 batches from the dataset and visually inspect them.

    Prints the dataset's static shapes and types. For each batch it then
    shows the first image with matplotlib and prints the image-batch shape
    and the labels. The loop stops early once the iterator is exhausted.
    """
    dataset = read_tfRecord()
    print('shapes:', dataset.output_shapes)
    print('types:', dataset.output_types)
    iterator = dataset.make_one_shot_iterator()
    next_op = iterator.get_next()
    with tf.Session() as sess:
        for step in range(40):
            print('--------------------------batch({})---------------------'.format(step))
            try:
                batch_image, batch_label = sess.run(next_op)
            except tf.errors.OutOfRangeError:
                print("队列已经遍历完成!")
                break
            # Display the first image of the batch to check the decode looks right.
            first_image = Image.fromarray(batch_image[0])
            plt.figure()
            plt.imshow(first_image)
            plt.show()
            print(batch_image.shape)
            print(batch_label, batch_label.shape, batch_label[0])
if __name__ == '__main__':
    # Run the read-back demo only when executed as a script.
    main()
程序结果: