TensorFlow 读取二进制文件

测试数据:
链接:cifar-10-binary.tar.gz

TensorFlow 读取二进制文件 写入tfrecords文件 tf1.0

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time

from tensorflow import keras

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)


def read_and_decode(file_list):
    """读取二进制文件"""
    height = 32
    width = 32
    channel = 3

    label_bytes = 1
    image_bytes = height * width * channel
    bytes_count = image_bytes + label_bytes

    # 1.构造文件队列
    file_queue = tf.train.string_input_producer(file_list)
    # 2.构造二进制文件读取器 每个样本的字节数
    reader = tf.FixedLengthRecordReader(bytes_count)

    key, value = reader.read(file_queue)
    # 3.解码内容, 二进制文件内容的解码
    label_image = tf.decode_raw(value, tf.uint8)

    print(label_image)

    # 4.分割出图片和标签数据 切出特征值的目标值
    label = tf.cast(tf.slice(label_image, [0], [label_bytes]), tf.int32)
    image = tf.slice(label_image, [label_bytes], [image_bytes])

    # 5.可以对图片的特征数据进行形状改变 [3072] -->[32, 32, 3]
    image_reshape = tf.reshape(image, [height, width, channel])
    print(label, image_reshape)

    # 6.批处理数据
    image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    print(image_batch, label_batch)

    return image_batch, label_batch


def write_to_tfrecords(image_batch, label_batch):
    """
    将图片的特征值和目标值存入tfrecords
    image_batch: 10张图片的特征值
    label_batch: 10张图片的目标值
    """
    # 1.构造一个tfrecords 文件 tfrecords存储器
    # 确认文件夹是否存在

    output_dir = "./cifar10_tfrecords"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    path = "./cifar10_tfrecords/cifar_1.tfrecords"  # 存入的文件
    writer = tf.python_io.TFRecordWriter(path)

    # 2.循环将所有样本写入文件, 每张图片样本都要构造example协议
    for i in range(10):
        # 取出第i个图片的特征值和目标值
        image = image_batch[i].eval().tostring()
        label = label_batch[i].eval()[0]

        # 构造一个样本的example协议
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))

        # 写入单独的样本
        writer.write(example.SerializeToString())

    # 关闭
    writer.close()

    return None


if __name__ == "__main__":

    # 找到文件,放入列表  路径+名字 ->列表当中
    file_names = os.listdir("./cifar10/cifar-10-batches-bin/")
    file_list = [os.path.join("./cifar10/cifar-10-batches-bin/", file)
                 for file in file_names if file[-3:] == "bin"]
    print(file_list)

    # 读取二进制文件
    image_batch, label_batch = read_and_decode(file_list)

    # 开启会话运行结果
    with tf.Session() as sess:
        # 定义一个线程协调
        coord = tf.train.Coordinator()
        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess, coord)

        # 显示二进制文件
        print(sess.run([image_batch, label_batch]))

        # 将二进制文件写入tfrecords
        print("开始存储")
        write_to_tfrecords(image_batch, label_batch )
        print("结束存储")

        # 回收子线程
        coord.request_stop()
        coord.join(threads)

TensorFlow 读取二进制文件 写入tfrecords文件 tf2.0

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)


# 读取二进制文件
def parse_bin_line(line, image_bytes_length=3072):
    """
    对每行内容解码
    :param image_bytes_length: 图片的长度
    :param line:每行数据
    return: x 特征值列表, y 目标值
    """
    # 1.对每行数据解码
    label_image = tf.io.decode_raw(line, tf.uint8)

    # 2.切分数据
    x = tf.slice(label_image, [1], [image_bytes_length])  # 图片

    y = tf.cast(tf.slice(label_image, [0], [1]), tf.int32)  # 标签

    return x, y


def write_to_tfrecords(cifar_data, file_path):
    """写入tfrecords文件"""

    with tf.io.TFRecordWriter(file_path) as writer:

        for x_batch, y_batch in cifar_data.take(10):

            image = x_batch.numpy().tostring()  # 图片
            label = y_batch.numpy()[0]          # 标签

            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))

            serialized_example = example.SerializeToString()

            writer.write(serialized_example)


if __name__ == "__main__":

    # 1.获取二进制文件列表
    file_names = os.listdir("./cifar10/cifar-10-batches-bin/")
    file_list = [os.path.join("./cifar10/cifar-10-batches-bin/", file)
                 for file in file_names if file[-3:] == "bin"]
    print(file_list)

    # 2.每条数据长度
    image_bytes = 32 * 32 * 3
    record_bytes = 1 + image_bytes

    # 3.读取文件列表
    cifar_data = tf.data.FixedLengthRecordDataset(file_list, record_bytes)

    # 4.映射解析每一条文件
    cifar_data = cifar_data.map(parse_bin_line, num_parallel_calls=3)

    # 5.设置读图片个数
    cifar_data = cifar_data.batch(1)

    # 6.显示dataset内容
    for x_batch, y_batch in cifar_data.take(2):
        print("x:")
        print(x_batch)
        print("y:")
        print(y_batch)

    # 文件路径
    output_dir = "./cifar10_tfrecords"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # 7.写入tfrecords文件
    path = "./cifar10_tfrecords/cifar_2.tfrecords"
    write_to_tfrecords(cifar_data, path)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

廷益--飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值