有了读取 JPG 和写入 TFRecord 的经验之后，开始尝试把 JPG 图像写入 TFRecord，另外还想尝试从 TFRecord 文件中读出 JPG 图像。
先上示例：把 JPG 图片的二进制数据以及高和宽信息保存进 TFRecord。
# Write a single JPEG into a TFRecord file: the encoded JPEG bytes are stored
# unchanged, together with the height and width obtained by decoding it once.
decode_jpeg_data = tf.placeholder(dtype=tf.string)
decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
tfrecords_filename = './tfrecords'

# Read inside `with` so the file handle is closed as soon as the bytes are in memory.
with tf.gfile.GFile("C:/Users/shenwei/Desktop/timg.jpg", 'rb') as image_file:
    image_data = image_file.read()
print(type(image_data))

with tf.Session() as sess:
    # Decode only to learn the image dimensions (rows, cols, channels).
    image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
print(image.shape[0])
print(image.shape[1])

example = tf.train.Example(features=tf.train.Features(
    feature={
        'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
        'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[0]])),
        'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[1]])),
    }))

# Create the .tfrecord file only once there is something to write, and let the
# context manager close it even if serialization or writing raises.
with tf.python_io.TFRecordWriter(tfrecords_filename) as writer:
    writer.write(example.SerializeToString())
********************下面是一段完整的保存及读取示例********************
参考：https://blog.csdn.net/MOU_IT/article/details/82773839
# Short alias for the TF-Slim library; used below for tfexample_decoder,
# dataset and dataset_data_provider.
# NOTE(review): tf.contrib is TensorFlow 1.x-only — confirm the pinned TF version.
slim = tf.contrib.slim
# Create the TFRecord file.
def create_record_file(train_filename="train.tfrecords",
                       image_dir="C:/Users/shenwei/Desktop/test/",
                       num_images=10):
    """Write JPEGs named ``0.jpg`` .. ``{num_images - 1}.jpg`` into a TFRecord file.

    Each record stores the raw (still-encoded) JPEG bytes, the format tag
    ``b'jpg'``, the image width and height (found by decoding the image once),
    and the image's index as its label.

    Args:
        train_filename: Output file name, written under the current directory.
            An existing file with this name is deleted first.
        image_dir: Directory containing the numbered ``.jpg`` files.
        num_images: How many images to write; labels run from 0 to
            ``num_images - 1``.
    """
    if os.path.exists(train_filename):
        os.remove(train_filename)

    # Build the decode op once and feed each image through a placeholder.
    # Creating tf.image.decode_jpeg inside the loop would add new nodes to
    # the default graph on every iteration.
    jpeg_bytes = tf.placeholder(dtype=tf.string)
    decoded = tf.image.decode_jpeg(jpeg_bytes)

    # The writer is closed by its context manager even if a read/decode fails.
    with tf.python_io.TFRecordWriter('./' + train_filename) as writer:
        with tf.Session() as sess:
            for i in range(num_images):
                image_path = os.path.join(image_dir, str(i) + ".jpg")
                with tf.gfile.GFile(image_path, 'rb') as image_file:
                    img_raw = image_file.read()
                image_shape = sess.run(decoded, feed_dict={jpeg_bytes: img_raw}).shape
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                        'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'jpg'])),
                        'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[image_shape[1]])),
                        'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[image_shape[0]])),
                        'image/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
                    }))
                writer.write(example.SerializeToString())  # serialize and append
    print("保存tfrecord文件成功。")
# Read the TFRecord file back using TF-Slim's dataset utilities.
def read_record_file(tfrecords_filename="train.tfrecords", num_samples=10,
                     num_classes=10):
    """Read ``num_samples`` examples from a TFRecord file via TF-Slim and print each image's shape.

    Args:
        tfrecords_filename: TFRecord file to read, in the layout written by
            ``create_record_file``.
        num_samples: Number of records in the file; also the number of
            examples fetched and printed.
        num_classes: Number of label classes, recorded in the Dataset metadata.
    """
    # Parsing spec used by tf to deserialize each stored tf.train.Example.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
    }
    # Map parsed features to higher-level items; slim decodes the image bytes.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format',
                                              channels=3),
        'label': slim.tfexample_decoder.Tensor('image/label'),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # Dataset bundles file location, reader type and decoder metadata.
    dataset = slim.dataset.Dataset(
        data_sources=tfrecords_filename,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=None,
        num_classes=num_classes,
    )
    # The provider reads and decodes records according to the dataset description.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=1,
        common_queue_capacity=20,
        common_queue_min=1)
    image, label, height, width = provider.get(['image', 'label', 'height', 'width'])

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_samples):
                img, lbl, h, w = sess.run([image, label, height, width])
                # img is already a decoded numpy array (3 channels requested above),
                # so no tf.reshape is needed here.
                print(img.shape)
        finally:
            # Always stop and join the queue-runner threads, even if sess.run fails,
            # so they are not left running against a closed session.
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    # Flip to True to (re)generate train.tfrecords before reading it back.
    REBUILD_RECORDS = False
    if REBUILD_RECORDS:
        create_record_file()
    read_record_file()