TensorFlow------读取二进制文件实例

TensorFlow------读取二进制文件实例:

 

class CifarRead(object):
    '''
    完成读取二进制文件,写进tfrecords,读取tfrecords
    :param object:
    :return:
    '''
    def __init__(self,filelist):
        # 文件列表
        self.file_list = filelist

        # 定义读取的图片的一些属性
        self.height = 32
        self.width = 32
        self.channel = 3
        # 二进制文件每张图片的字节
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self):
        # 1. 构建文件队列
        file_queue = tf.train.string_input_producer(self.file_list)

        # 2. 构建二进制文件读取器,读取内容,每个样本的字节数
        reader = tf.FixedLengthRecordReader(self.bytes)

        key,value = reader.read(file_queue)

        # 3. 解码内容,二进制文件内容的解码 label_image包含目标值和特征值
        label_image = tf.decode_raw(value,tf.uint8)
        print(label_image)

        # 4.分割出图片和标签数据,特征值和目标值
        label = tf.slice(label_image,[0],[self.label_bytes])

        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
        print('---->')
        print(image)

        # 5. 可以对图片的特征数据进行形状的改变 [3072]-->[32,32,3]
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])

        print('======>')
        print(label)
        print('======>')

        # 6. 批处理数据
        image_batch,label_batch = tf.train.batch([image_reshape,label],batch_size=10,num_threads=1,capacity=10)

        print(image_batch,label_batch)

        return image_batch,label_batch


if __name__ == '__main__':
    # 找到文件,构建列表  路径+名字  ->列表当中
    file_name = os.listdir(FLAGS.cifar_dir)

    # 拼接路径 重新组成列表
    filelist = [os.path.join(FLAGS.cifar_dir,file) for file in file_name if file[-3:] == 'bin']

    # 调用函数传参
    cf = CifarRead(filelist)
    image_batch,label_batch = cf.read_and_decode()

    # 开启会话
    with tf.Session() as sess:
        # 定义一个线程协调器
        coord = tf.train.Coordinator()

        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess,coord=coord)

        # 打印读取的内容
        print(sess.run([image_batch,label_batch]))

        # 回收子线程
        coord.request_stop()

        coord.join(threads)

 

转载于:https://www.cnblogs.com/fwl8888/p/9762466.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值