制作 TFRecord 的关键代码:
# Wrap one sample (integer label + raw image bytes) in a tf.train.Example.
label_feature = tf.train.Feature(
    int64_list=tf.train.Int64List(value=[img_and_label[1]]))  # input label
image_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[img_raw]))  # input image
example = tf.train.Example(features=tf.train.Features(feature={
    "label": label_feature,
    'img_raw': image_feature,
}))
# Serialize the Example to a string and append it to the TFRecord file.
writer.write(example.SerializeToString())
整个代码
# -*- coding: utf-8 -*-
import tensorflow as tf
import os
from PIL import Image
from functions import *
# `dir`, `findtxt` and `find_img_and_label` are provided by the star import
# from `functions` above.
txtlist = []
findtxt(dir, txtlist)  # collect every annotation .txt file under the data root

# The context manager closes (and flushes) the TFRecord file even when an
# image fails to load half-way through the loop.
with tf.python_io.TFRecordWriter("/home/yuejian/Desktop/1234567.tfrecords") as writer:  # file to generate
    for element in txtlist:
        # element looks like
        # /home/yuejian/Desktop/tfrecord/train_data/train_data/img_627.txt
        img_and_label = find_img_and_label(element)

        # resize() returns a new, fully loaded image, so the source file can
        # be closed as soon as the resized copy exists.
        with Image.open(img_and_label[0]) as src:
            img = src.resize((128, 128))
        img_raw = img.tobytes()  # raw pixel bytes of the resized image

        # The Example bundles the label and the image bytes of one sample.
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[img_and_label[1]])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())  # serialize to a string
functions.py
import os
import tensorflow as tf
# Root directory of the training data: it is the directory that image file
# names read from the annotation files are joined onto.
# NOTE(review): this name shadows the built-in dir(); it is kept because code
# that star-imports this module reads the constant under this exact name.
dir="/home/yuejian/Desktop/tfrecord/train_data/train_data"
def findtxt(path, ret):
    """Recursively collect the paths of all ``*.txt`` files below ``path``.

    Args:
        path: Directory to search.
        ret: List that the matching file paths are appended to (modified in
            place). Each entry is ``path`` joined with the relative location
            of the file.

    Returns:
        None. Results are delivered through ``ret``.

    Raises:
        OSError: If ``path`` (or a sub-directory) cannot be listed.
    """
    # Sorting makes the traversal order reproducible; os.listdir() itself
    # returns entries in arbitrary order.
    for filename in sorted(os.listdir(path)):
        de_path = os.path.join(path, filename)
        if os.path.isdir(de_path):
            findtxt(de_path, ret)
        elif os.path.isfile(de_path) and de_path.endswith(".txt"):
            ret.append(de_path)
        # Anything else (broken symlink, socket, ...) is skipped: it is
        # neither a text file nor something that can be listed.
def find_img_and_label(txt_path):
    """Read an annotation file and return the image path and integer label.

    The text before the last comma is taken as the image file name and is
    joined onto the module-level ``dir``. The label is the integer that
    follows the last separator (space or comma, whichever comes later), so
    both ``"img_1.jpg, 7"`` and ``"img_1.jpg,7"`` are accepted. Trailing
    whitespace in the file is ignored.

    Args:
        txt_path: Path of the annotation text file.

    Returns:
        A tuple ``(img_path, label)`` where ``label`` is an ``int``.

    Raises:
        ValueError: If the file contains no separator or the label is not an
            integer.
    """
    # `with` guarantees the handle is closed even if read() raises.
    with open(txt_path) as f:
        # Drop trailing whitespace so a stray space or newline after the
        # label does not become the "last separator".
        txt = f.read().rstrip()

    comma = txt.rfind(',')
    # Without a comma there is no file name; the relative path is empty.
    path = txt[:max(comma, 0)]
    img_path = os.path.join(dir, path)

    separator = max(comma, txt.rfind(' '))
    if separator == -1:
        raise ValueError("no label separator found in %r" % (txt_path,))
    label = int(txt[separator + 1:])
    return (img_path, label)