import os
import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
# Everything below builds TF1-style graphs and runs them in tf.Session with
# queue runners, so eager execution is switched off before any op is created.
tf.disable_eager_execution()
# import tensorflow as tf
class Cifar(object):
    """Read CIFAR-10 binary batch files and round-trip samples through TFRecords.

    Each fixed-length record in a CIFAR-10 ``.bin`` file is ``label_bytes``
    label byte(s) followed by ``image_bytes`` image bytes laid out channel-major
    as (channels, height, width). Every method builds a TF1-style graph and runs
    it in a ``tf.Session`` driven by queue runners.
    """

    def __init__(self):
        # Geometry of a single CIFAR-10 image.
        self.height = 32
        self.width = 32
        self.channels = 3
        # Byte counts of one fixed-length record: label byte + image bytes.
        self.image_bytes = self.width * self.height * self.channels
        self.label_bytes = 1
        self.all_bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list, batch_size=100):
        """Read one batch of samples from CIFAR-10 binary files.

        :param file_list: paths of CIFAR-10 ``.bin`` files.
        :param batch_size: number of samples per batch (also the queue capacity).
        :return: ``(label_value, image_value)`` numpy arrays: labels of shape
            (batch_size, label_bytes) with dtype uint8, and images of shape
            (batch_size, height, width, channels) with dtype float32.
        """
        # 1. Filename queue.
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read: every record is exactly all_bytes long.
        reader = tf.FixedLengthRecordReader(self.all_bytes)
        # key identifies where the record came from; value is the raw record.
        key, value = reader.read(file_queue)
        print("key:", key)
        print("value:", value)

        # Decode the raw record into a flat uint8 vector.
        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:\n", decoded)

        # Split the label byte(s) from the image bytes.
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
        print("label:\n", label)
        print("image:\n", image)

        # The image bytes are channel-major, so reshape to (C, H, W) first...
        image_reshaped = tf.reshape(image, shape=[self.channels, self.height, self.width])
        print("image_reshaped:\n", image_reshaped)
        # ...then move the channel axis last to obtain (H, W, C).
        image_transpose = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transpose:\n", image_transpose)
        image_cast = tf.cast(image_transpose, tf.float32)

        # 3. Batch queue.
        label_batch, image_batch = tf.train.batch(
            [label, image_cast], batch_size=batch_size, num_threads=1, capacity=batch_size)
        print("lable_batch", label_batch)
        print("image_batch", image_batch)

        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                label_value, image_value = sess.run([label_batch, image_batch])
            finally:
                # Always stop and join the reader threads, even if run() fails.
                coord.request_stop()
                coord.join(threads)
        return label_value, image_value

    @staticmethod
    def _image_to_bytes(image):
        """Return the raw uint8 bytes of one image array.

        read_and_decode yields float32 pixels that were cast from uint8, so they
        are whole numbers in [0, 255] and converting back is lossless.
        read_tfrecords decodes with tf.uint8 and reshapes to height*width*channels
        values, so the stored bytes must be one byte per pixel component.
        """
        return image.astype("uint8").tobytes()

    def write_to_tfrecords(self, image_batch, label_batch, path="cifar10.tfrecords"):
        """Serialize a batch of samples into a TFRecord file, one Example each.

        :param image_batch: numpy array of images, e.g. the ``image_value``
            returned by ``read_and_decode``.
        :param label_batch: numpy array of labels, e.g. the ``label_value``
            returned by ``read_and_decode`` (one label per row).
        :param path: output TFRecord file path.
        :return: None
        """
        with tf.python_io.TFRecordWriter(path) as writer:
            for image_array, label_array in zip(image_batch, label_batch):
                image = self._image_to_bytes(image_array)
                # Int64List wants a plain Python int, not a numpy scalar.
                label = int(label_array[0])
                print("tfrecords_image:\n", image)
                print("tfrecords_label:\n", label)
                example = tf.train.Example(features=tf.train.Features(feature={
                    # Key must match the one parsed in read_tfrecords.
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                writer.write(example.SerializeToString())
        return None

    def read_tfrecords(self, path="cifar10.tfrecords", batch_size=100):
        """Read back one batch of samples written by ``write_to_tfrecords``.

        Steps: (1) build a filename queue, (2) read, parse each Example and
        decode the image bytes, (3) build a batch queue. The fetched batch is
        printed.

        :param path: TFRecord file to read.
        :param batch_size: number of samples per batch (also the queue capacity).
        :return: None
        """
        # 1. Filename queue.
        file_queue = tf.train.string_input_producer([path])

        # 2. Read and parse the serialized Example.
        reader = tf.TFRecordReader()
        _, value = reader.read(file_queue)
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)

        # Images were stored as raw uint8 bytes in (H, W, C) order.
        image_decoded = tf.decode_raw(image, tf.uint8)
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        print("image_reshped:\n", image_reshaped)

        # 3. Batch queue.
        image_batch, label_batch = tf.train.batch(
            [image_reshaped, label], batch_size=batch_size, num_threads=2, capacity=batch_size)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                image_value, label_value = sess.run([image_batch, label_batch])
                print("image_value:\n", image_value)
                print("label_value:\n", label_value)
            finally:
                coord.request_stop()
                coord.join(threads)
        return None
def picture_read(file_list, *, size=(200, 200), batch_size=100):
    """Read JPEG images (the dog-picture example) into one resized batch.

    :param file_list: paths of JPEG image files.
    :param size: (height, width) that every image is resized to.
    :param batch_size: number of images per batch (also the queue capacity).
    :return: None; the evaluated tensors are printed.
    """
    # 1. Filename queue.
    file_queue = tf.train.string_input_producer(file_list)

    # 2. Read: key is the file name, value is the file's raw JPEG bytes.
    reader = tf.WholeFileReader()
    key, value = reader.read(file_queue)
    print("key:", key)
    print("value:", value)

    # Decode. Forcing 3 channels keeps grayscale JPEGs compatible with the
    # 3-channel static shape set below.
    image = tf.image.decode_jpeg(value, channels=3)
    print("image;", image)

    # Resize to a common size and pin the static shape; tf.train.batch needs
    # fully defined shapes.
    image_resize = tf.image.resize_images(image, list(size))
    print("image_resize:", image_resize)
    image_resize.set_shape(shape=[size[0], size[1], 3])

    # 3. Batch queue.
    image_batch = tf.train.batch([image_resize], batch_size=batch_size, num_threads=1, capacity=batch_size)
    print("image_batch:", image_batch)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            key_new, value_new, image_new, image_resize_new, image_batch_new = \
                sess.run([key, value, image, image_resize, image_batch])
            print("key_new:", key_new)
            print("value_new:", value_new)
            print("image_new:", image_new)
            print("image_resize_new:", image_resize_new)
            print("image_batch_new:", image_batch_new)
        finally:
            # Always stop and join the reader threads, even if run() fails.
            coord.request_stop()
            coord.join(threads)
    return None
if __name__ == "__main__":
    # Directory containing this script. abspath() matters: dirname() of a bare
    # script name is "", and os.listdir("") raises FileNotFoundError.
    dir_name = os.path.dirname(os.path.abspath(__file__))
    file_names = os.listdir(dir_name)
    print(dir_name)
    print(file_names)
    # Full paths of the CIFAR-10 binary batch files next to this script.
    file_list = [os.path.join(dir_name, name) for name in file_names if name.endswith(".bin")]
    print(file_list)
    # Toggle the steps below as needed:
    # picture_read(file_list)
    cifar = Cifar()  # do not rebind the class name to an instance
    # label_value, image_value = cifar.read_and_decode(file_list)
    # cifar.write_to_tfrecords(image_value, label_value)
    cifar.read_tfrecords()
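# 补充读二进制文件 (Supplement: reading binary files)
# 最新推荐文章于 2024-07-19 16:36:18 发布 (web-page footer: latest recommended article published 2024-07-19 16:36:18)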