测试数据：CIFAR-10 二进制版本
数据文件：cifar-10-binary.tar.gz（解压到 ./cifar10/ 目录下，代码从 ./cifar10/cifar-10-batches-bin/ 读取其中的 .bin 文件）
TensorFlow 1.x（tf.compat.v1 图模式）：读取二进制文件并写入 TFRecords 文件
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tensorflow import keras
import tensorflow.compat.v1 as tf
# Run the TF 1.x graph/session API (queues, Session) on a TF 2.x install.
tf.disable_v2_behavior()
print(tf.__version__)
print(sys.version_info)
# Report the installed version of every third-party library imported above.
for lib in (mpl, np, pd, sklearn, tf, keras):
    name, version = lib.__name__, lib.__version__
    print(name, version)
def read_and_decode(file_list):
    """Build a TF 1.x queue-based pipeline that reads CIFAR-10 binary records.

    Each fixed-length record is 1 label byte followed by 32*32*3 image bytes.
    CIFAR-10 stores the image bytes channel-major (1024 red values, then 1024
    green, then 1024 blue, each plane row-major). The bytes are therefore
    reshaped to [3, 32, 32] and transposed to the conventional [32, 32, 3]
    layout.

    :param file_list: paths of the CIFAR-10 ``.bin`` files to read.
    :return: (image_batch, label_batch), where images are uint8 of shape
        [10, 32, 32, 3] and labels are int32 of shape [10, 1]. The queue
        runners created here must be started (tf.train.start_queue_runners)
        before these tensors are evaluated.
    """
    height = 32
    width = 32
    channel = 3
    label_bytes = 1
    image_bytes = height * width * channel
    bytes_count = image_bytes + label_bytes

    # 1. Filename queue that feeds the reader.
    file_queue = tf.train.string_input_producer(file_list)
    # 2. Every record has the same byte length, so a fixed-length reader fits.
    reader = tf.FixedLengthRecordReader(bytes_count)
    _, value = reader.read(file_queue)
    # 3. Reinterpret the raw record bytes as a flat uint8 vector.
    label_image = tf.decode_raw(value, tf.uint8)
    print(label_image)
    # 4. Split into the label (first byte) and the image (remaining bytes).
    label = tf.cast(tf.slice(label_image, [0], [label_bytes]), tf.int32)
    image = tf.slice(label_image, [label_bytes], [image_bytes])
    # 5. [3072] -> [3, 32, 32] (file layout) -> [32, 32, 3].
    #    Reshaping straight to [32, 32, 3] would interleave the colour planes
    #    and scramble the pixels.
    depth_major = tf.reshape(image, [channel, height, width])
    image_reshape = tf.transpose(depth_major, [1, 2, 0])
    print(label, image_reshape)
    # 6. Group single examples into batches of 10.
    image_batch, label_batch = tf.train.batch(
        [image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    print(image_batch, label_batch)
    return image_batch, label_batch
def write_to_tfrecords(image_batch, label_batch):
    """Write one evaluated batch of images and labels to a TFRecord file.

    The file is ./cifar10_tfrecords/cifar_1.tfrecords. Must be called while a
    default tf.Session is active (e.g. inside ``with tf.Session() as sess:``)
    and its queue runners are running.

    :param image_batch: image tensor whose first dimension is the batch.
    :param label_batch: label tensor of shape [batch, 1].
    :raises ValueError: if no default session is active.
    """
    output_dir = "./cifar10_tfrecords"
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, "cifar_1.tfrecords")

    sess = tf.get_default_session()
    if sess is None:
        raise ValueError(
            "write_to_tfrecords must be called inside a default tf.Session")
    # Evaluate BOTH tensors in ONE run. Every separate .eval() re-executes the
    # graph and dequeues a fresh batch, so evaluating image_batch[i] and
    # label_batch[i] separately would pair images and labels from different
    # batches.
    images, labels = sess.run([image_batch, label_batch])

    # Context manager guarantees the writer is closed even if a write fails.
    with tf.python_io.TFRecordWriter(path) as writer:
        # One Example per image in the batch.
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                # tobytes(): ndarray.tostring() was removed in NumPy 2.0.
                "image": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                # label has shape [1]; hand Int64List a plain Python int.
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label[0])])),
            }))
            writer.write(example.SerializeToString())
    return None
if __name__ == "__main__":
    data_dir = "./cifar10/cifar-10-batches-bin/"
    # Collect the binary batch files; sorted so the listing is deterministic
    # (os.listdir order is arbitrary).
    file_list = sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.endswith(".bin")
    )
    print(file_list)
    if not file_list:
        raise FileNotFoundError(f"No .bin files found in {data_dir}")

    # Build the binary-reading pipeline.
    image_batch, label_batch = read_and_decode(file_list)

    with tf.Session() as sess:
        # Coordinator used to stop the reader threads.
        coord = tf.train.Coordinator()
        # Start the queue-runner threads that read the files.
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            # Show one decoded batch.
            print(sess.run([image_batch, label_batch]))
            # Write a batch to the TFRecord file.
            print("开始存储")
            write_to_tfrecords(image_batch, label_batch)
            print("结束存储")
        finally:
            # Always stop and join the reader threads, even if a step above
            # raised, so the session is not closed under running threads.
            coord.request_stop()
            coord.join(threads)
TensorFlow 2.x（tf.data 即时执行）：读取二进制文件并写入 TFRecords 文件
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(sys.version_info)
# Report the installed version of every third-party library imported above.
for lib in (mpl, np, pd, sklearn, tf, keras):
    name, version = lib.__name__, lib.__version__
    print(name, version)
# 读取二进制文件
def parse_bin_line(line, image_bytes_length=3072):
    """Decode one fixed-length binary record into (image bytes, label).

    The record is a single label byte followed by ``image_bytes_length``
    image bytes.

    :param line: one raw record as a string tensor (as yielded by
        tf.data.FixedLengthRecordDataset in the script below).
    :param image_bytes_length: number of image bytes after the label byte.
    :return: (x, y). x is the flat uint8 image of length
        ``image_bytes_length``, kept in file order; CIFAR-10 stores it
        channel-major, i.e. all red values, then green, then blue. y is an
        int32 tensor of shape [1] holding the label.
    """
    raw = tf.io.decode_raw(line, tf.uint8)
    # Byte 0 is the label; the image starts immediately after it.
    image = tf.slice(raw, begin=[1], size=[image_bytes_length])
    label = tf.cast(tf.slice(raw, begin=[0], size=[1]), dtype=tf.int32)
    return image, label
def write_to_tfrecords(cifar_data, file_path, max_batches=10):
    """Write examples from a batched (image, label) dataset to a TFRecord file.

    :param cifar_data: tf.data.Dataset yielding (x_batch, y_batch) pairs,
        where each row of x_batch is one flat uint8 image and each row of
        y_batch holds that image's label (e.g. parse_bin_line followed by
        .batch()).
    :param file_path: destination .tfrecords path.
    :param max_batches: number of batches to take from cifar_data
        (default 10, i.e. 10 images when the dataset is batched by 1).
    """
    with tf.io.TFRecordWriter(file_path) as writer:
        for x_batch, y_batch in cifar_data.take(max_batches):
            # Emit one Example per image instead of serialising a whole batch
            # as a single "image".
            for image, label in zip(x_batch.numpy(), y_batch.numpy()):
                # Each label row has shape [1]; a 1-element ndarray is not an
                # integer, so flatten it and pass a plain Python int.
                label_value = int(np.asarray(label).reshape(-1)[0])
                example = tf.train.Example(features=tf.train.Features(feature={
                    # tobytes(): ndarray.tostring() was removed in NumPy 2.0.
                    "image": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                    "label": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[label_value])),
                }))
                writer.write(example.SerializeToString())
if __name__ == "__main__":
    data_dir = "./cifar10/cifar-10-batches-bin/"
    # 1. Collect the binary batch files. Sorted because os.listdir() order is
    #    arbitrary and FixedLengthRecordDataset reads files in list order, so
    #    an unsorted list can change which examples get written.
    file_list = sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.endswith(".bin")
    )
    print(file_list)
    if not file_list:
        raise FileNotFoundError(f"No .bin files found in {data_dir}")

    # 2. Length of one record: 1 label byte + 32*32*3 image bytes.
    image_bytes = 32 * 32 * 3
    record_bytes = 1 + image_bytes
    # 3. Dataset of raw fixed-length records from all files.
    cifar_data = tf.data.FixedLengthRecordDataset(file_list, record_bytes)
    # 4. Decode every record; pass image_bytes so the parse length always
    #    matches the record length used above.
    cifar_data = cifar_data.map(
        lambda line: parse_bin_line(line, image_bytes), num_parallel_calls=3)
    # 5. One image per batch.
    cifar_data = cifar_data.batch(1)
    # 6. Show the first two batches.
    for x_batch, y_batch in cifar_data.take(2):
        print("x:")
        print(x_batch)
        print("y:")
        print(y_batch)

    # Output directory (created if missing, race-free).
    output_dir = "./cifar10_tfrecords"
    os.makedirs(output_dir, exist_ok=True)
    # 7. Write the examples to a TFRecord file.
    path = os.path.join(output_dir, "cifar_2.tfrecords")
    write_to_tfrecords(cifar_data, path)