#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 29 11:39:25 2020
@author: gaofeng
"""
import tensorflow as tf
import os
import cv2
def write(input_file, output_file):
writer = tf.io.TFRecordWriter(output_file) #定义writer,传入目标文件路径
path = input_file
file_names = [f for f in os.listdir(path) if f.endswith('.png')] #获取待存文件路径
i=1
for file_name in file_names:
img = cv2.imread(path + file_name)
raw_img = img.tobytes() #需要把图片文件转化成bytes形式(二进制比特流)
file_name=bytes(file_name, encoding='utf-8')
# 把数据合并成feature,注意这里的"value="后面一定要是一个"[]"形式的列表,否则读取的时候会出现can't parse的情况
features = tf.train.Features(feature={'img_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[file_name])),
'raw_img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_img]))})
#把features存入example
example = tf.train.Example(features=features)
#example序列化,并写入文件
writer.write(example.SerializeToString())
i=i+1
print(i)
writer.close()
input_file = './image/'
output_file = 'samples.tfrecords'
#writer= tf.python_io.TFRecordWriter('./image/'+ 'laneimage.tfrecords')
write(input_file, output_file)
print('Write tfrecords: %s done' %output_file)
制作tfrecord
最新推荐文章于 2021-05-31 21:01:46 发布