# 基本操作,先导入要使用的工具
import tensorflow as tf
import numpy as np
from PIL import Image
import os,glob
# Expose only CUDA device 0 to this process. This does not "enable" the GPU;
# it limits which GPUs TensorFlow is allowed to see.
# NOTE(review): this line runs after `import tensorflow`; it only takes effect
# if TensorFlow has not yet initialised its GPU devices — confirm.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
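# 第一步,构造需要用到的函数 (Step 1: define the helper functions)
# 制作Feature的数据,类型包括了Float、Int64和Bytes (build tf.train.Feature values of type Float, Int64 and Bytes)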
def Float_Feature(value):
    """Return a ``tf.train.Feature`` whose float_list holds the single *value*."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)
def Int64_Feature(value):
    """Return a ``tf.train.Feature`` whose int64_list holds the single *value*."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def Byte_Feature(value):
    """Return a ``tf.train.Feature`` whose bytes_list holds the single *value*."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
# 第二步,读取图片路径并制作数据集 (Step 2: read the image paths and build the TFRecord dataset)
def loadAndMakeTFRecord(path, output_path="./K.tfrecord"):
    """Encode every image matching the glob *path* into one TFRecord file.

    Each image becomes one ``tf.train.Example`` with the features
    ``width``, ``height``, ``mode`` (channel count after decoding with
    ``channels=3``) and ``raw_image`` (the original, undecoded file bytes).

    Args:
        path: glob pattern for the input JPEG files, e.g. ``E:/data/*.jpg``.
        output_path: destination TFRecord file (default ``./K.tfrecord``).
    """
    # Sort so the record order is reproducible; glob order is arbitrary.
    pictures = sorted(glob.glob(path))
    print("已准备好写入器!")
    # The context manager closes the writer even if an image fails to
    # decode part-way through the loop.
    with tf.io.TFRecordWriter(output_path) as tfrecordWriter:
        for single_pic in pictures:
            print(single_pic)
            examples = _make_image_example(single_pic)
            print("开始写入!")
            tfrecordWriter.write(examples.SerializeToString())
    print("数据集写入完成!")


def _make_image_example(pic_path):
    """Read one image file and wrap its bytes and shape in a tf.train.Example."""
    with open(pic_path, 'rb') as f:
        binary_pic = f.read()
    # Decode the bytes already in memory rather than re-reading the file.
    # decode_jpeg yields shape (height, width, channels): index 0 is the
    # height, not the width.
    height, width, mode = tf.image.decode_jpeg(binary_pic, channels=3).shape
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'width': Int64_Feature(width),
                'height': Int64_Feature(height),
                'mode': Int64_Feature(mode),
                # Store the encoded bytes; the parser decodes them again with
                # tf.image.decode_jpeg, which takes a string tensor.
                'raw_image': Byte_Feature(binary_pic),
            }
        )
    )
# 第三步,解析数据集 (Step 3: parse the dataset)
def parseDataSets(tfrecordData):
    """Parse one serialized Example and return its image resized to 64x64.

    Args:
        tfrecordData: one serialized ``tf.train.Example`` record, as yielded
            by ``tf.data.TFRecordDataset`` when this is used with ``.map``.

    Returns:
        The decoded 3-channel image resized to (64, 64, 3) by
        ``tf.image.resize`` (default bilinear method).
    """
    # Scalar int64 metadata fields first, then the encoded image bytes.
    schema = {
        name: tf.io.FixedLenFeature([], tf.int64)
        for name in ('width', 'height', 'mode')
    }
    schema['raw_image'] = tf.io.FixedLenFeature([], tf.string)
    parsed = tf.io.parse_single_example(tfrecordData, schema)
    # decode_jpeg consumes the string tensor written from the raw file bytes.
    decoded = tf.image.decode_jpeg(parsed['raw_image'], channels=3)
    return tf.image.resize(decoded, (64, 64))
def showImage(data, grid=10, tile_size=64, show=True):
    """Paste images from *data* into a ``grid`` x ``grid`` mosaic and show it.

    Only the first ``grid * grid`` images are used; previously any image past
    that point was pasted outside the canvas and silently lost.

    Args:
        data: iterable of batches, each an iterable of images of size
            ``tile_size`` x ``tile_size`` with 3 channels. Pixel values are
            assumed to lie in [0, 255]; they are truncated via ``np.uint8``.
        grid: number of tiles per row and per column (default 10).
        tile_size: side length in pixels of each tile (default 64).
        show: open the mosaic with ``Image.show()`` (default True).

    Returns:
        The assembled RGB ``PIL.Image.Image``.
    """
    import itertools

    side = grid * tile_size
    new_img = Image.new('RGB', (side, side))
    # Flatten batches into one stream of images, capped at the canvas
    # capacity so no further items are pulled from the dataset.
    pictures = (pic for batch in data for pic in batch)
    for index, singlePic in enumerate(itertools.islice(pictures, grid * grid)):
        row, col = divmod(index, grid)
        tile = Image.fromarray(np.uint8(singlePic), 'RGB')
        new_img.paste(tile, (col * tile_size, row * tile_size))
    if show:
        new_img.show()
    return new_img
# 第四步,测试程序 (Step 4: test the program)
# 主程序入口 (main entry point)
if __name__ == '__main__':
    pictures_path = r"E:\data\single\*.jpg"
    repeateTime = 2
    loadAndMakeTFRecord(pictures_path)
    # Read back the TFRecord file that was just written.
    tfData = tf.data.TFRecordDataset("./K.tfrecord")
    # Do not also .repeat(repeateTime) here. Each showImage call below is
    # already one pass over the data, so repeating in the pipeline made every
    # window show each image repeateTime times. shuffle() reshuffles on each
    # new pass by default, so every window gets a fresh order.
    tfData = tfData.map(parseDataSets).shuffle(100).batch(3, drop_remainder=True)
    for _ in range(repeateTime):
        showImage(tfData)