在tensorflow 2 中制作和使用 tfrecord数据集

# 基本操作,先导入要使用的工具

import tensorflow as tf
import numpy as np
from PIL import Image
import os,glob

os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 导入需要用到的包之后,启用一下GPU

第一步,构造需要用到的函数

#制作Featuer的数据,类型包括了Float,Int64和Bytes

def Float_Feature(value):
    return tf.train.Feature(float_list = tf.train.FloatList(value = [value]))

def Int64_Feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

def Byte_Feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

第二步,读取图片路径并制作数据集

def loadAndMakeTFRecord(path):
    pictures = glob.glob(path)
    print("已准备好写入器!")
    tfrecordWriter = tf.io.TFRecordWriter("./K.tfrecord")

    for single_pic in pictures:
        print(single_pic)
        with open(single_pic,'rb') as f:
            binary_pic = f.read()
            pic_for_shape = tf.io.read_file(single_pic)
            picShape = tf.image.decode_jpeg(pic_for_shape,channels=3).shape

            examples = tf.train.Example(

                features = tf.train.Features(

                    feature = {

                        'width':Int64_Feature(picShape[0]),
                        'height':Int64_Feature(picShape[1]),
                        'mode':Int64_Feature(picShape[2]),
                        'raw_image':Byte_Feature(binary_pic)  #图片必须二进制写入,以字符串形式读出

                    }

                )

            )
            print("开始写入!")
            tfrecordWriter.write(examples.SerializeToString())
    tfrecordWriter.close()
    print("数据集写入完成!")

第三步,解析数据集

def parseDataSets(tfrecordData):

    feature = {
                'width':tf.io.FixedLenFeature([],tf.int64),
                'height':tf.io.FixedLenFeature([],tf.int64),
                'mode':tf.io.FixedLenFeature([],tf.int64),
                'raw_image':tf.io.FixedLenFeature([],tf.string)
    }

    single_example = tf.io.parse_single_example(tfrecordData,feature)
    raw_image = tf.image.decode_jpeg(single_example['raw_image'],channels=3)  #图片必须二进制写入,以字符串形式读出,是因为decode_jpeg需要“A Tensor of String”
    raw_image_tensor = tf.image.resize(raw_image,(64,64))
    return raw_image_tensor



def showImage(data):
    new_img = Image.new('RGB',(10*64,10*64))
    x = 0
    y = 0
    for singleBatch in data:
        for singlePic in singleBatch:
            images = Image.fromarray(np.uint8(singlePic),'RGB')
            new_img.paste(images,(x,y))
            x += 64
            if x >= 10*64:
                x = 0
                y += 64

    new_img.show()

第四步,测试程序

#主程序入口
if __name__ == '__main__':
    pictures_path = r"E:\data\single\*.jpg"
    repeateTime = 2

    loadAndMakeTFRecord(pictures_path)
    tfData = tf.data.TFRecordDataset("./K.tfrecord") #读取已经制作好的数据集

    tfData = tfData.map(parseDataSets).shuffle(100).batch(3,drop_remainder=True).repeat(repeateTime)

    for _ in range(repeateTime):
        showImage(tfData)

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
将VOC2007数据集转换为TFRecord文件需要以下步骤: 1. 下载VOC2007数据集并解压缩。 2. 安装TensorFlow和Pillow库。 3. 编写脚本将VOC2007数据集转换为TFRecord文件。 以下是一个简单的Python脚本示例,可以将VOC2007数据集转换为TFRecord文件: ```python import tensorflow as tf import os import io import xml.etree.ElementTree as ET from PIL import Image def create_tf_example(example): # 读取图像文件 img_path = os.path.join('VOCdevkit/VOC2007/JPEGImages', example['filename']) with tf.io.gfile.GFile(img_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = Image.open(encoded_jpg_io) width, height = image.size # 读取标注文件 xml_path = os.path.join('VOCdevkit/VOC2007/Annotations', example['filename'].replace('.jpg', '.xml')) with tf.io.gfile.GFile(xml_path, 'r') as fid: xml_str = fid.read() xml = ET.fromstring(xml_str) # 解析标注文件 xmins = [] xmaxs = [] ymins = [] ymaxs = [] classes_text = [] classes = [] for obj in xml.findall('object'): class_name = obj.find('name').text classes_text.append(class_name.encode('utf8')) classes.append(label_map[class_name]) bbox = obj.find('bndbox') xmins.append(float(bbox.find('xmin').text) / width) ymins.append(float(bbox.find('ymin').text) / height) xmaxs.append(float(bbox.find('xmax').text) / width) ymaxs.append(float(bbox.find('ymax').text) / height) # 构造TFRecord Example tf_example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])), 'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])), 'image/filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['filename'].encode('utf8')])), 'image/source_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['filename'].encode('utf8')])), 'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_jpg])), 'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf8')])), 'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmins)), 'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmaxs)), 'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymins)), 'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymaxs)), 'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=classes_text)), 'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=classes)), })) return tf_example # 将VOC2007数据集转换为TFRecord文件 def create_tf_record(output_file): examples = [...] # 从VOC2007数据集读取实例 writer = tf.io.TFRecordWriter(output_file) for example in examples: tf_example = create_tf_example(example) writer.write(tf_example.SerializeToString()) writer.close() label_map = {...} # 标签映射 output_file = 'voc2007_train.tfrecord' create_tf_record(output_file) ``` 其`create_tf_example`函数将一个VOC2007样本转换为TFRecord Example,`create_tf_record`函数将整个VOC2007数据集转换为TFRecord文件。在这个例子,我们假设VOC2007数据集已经被解压缩到`VOCdevkit/VOC2007`目录下,标签映射已经定义为`label_map`变量。你需要根据自己的实际情况修改这些变量。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值