Tensorflow基础:TFRecords的扩展应用-将图片存储到TFRecords文件中,并从中读取

这是对上一篇文章的应用,内容相似,不多说了,直接上代码。
一、存储
图片文件存储结构:
在这里插入图片描述

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
# author:Dr.Shang

"""
使用每个图片的文件名作为标签
"""

import os
import tensorflow as tf
from PIL import Image


path = 'jpg'
filenames = os.listdir(path)  # 返回path路径下的文件夹名称列表
writer = tf.python_io.TFRecordWriter('TFRecords/train.tfrecords')

for label, name in enumerate(os.listdir(path), start=1):
    class_path = path + os.sep + name # os.sep==》/
    for img_name in os.listdir(class_path):
        img_path = class_path + os.sep + img_name
        img = Image.open(img_path)
        img = img.resize((500, 500))
        img_raw = img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())
        print('类别{}:文件{}存储成功'.format(name, img_name))


二、读取

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
# author:Dr.Shang


import tensorflow as tf
import cv2


filename = 'TFRecords/train.tfrecords'
filename_queue = tf.train.string_input_producer([filename])

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'image': tf.FixedLenFeature([], tf.string)
                                   })

img = tf.decode_raw(features['image'], tf.uint8)
img = tf.reshape(img, [500, 500, 3])

img = tf.cast(img, tf.float32) * (1. / 500)
label = tf.cast(features['label'], tf.int32)

img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=1, capacity=10, min_after_dequeue=6)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for _ in range(2):
    img = sess.run(img_batch)
    label = sess.run(label_batch)
    img.resize((500, 500, 3)) # 没有这行报错
    cv2.imshow('test', img)
    cv2.waitKey()
    print(label)



  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值