TFRecords制作

为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord是一种比较常用的存储二进制序列数据的方法

  • tf.Example类是一种将数据表示为{“string”: value}形式的meassage类型,Tensorflow经常使用tf.Example来写入、读取TFRecord数据
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import tensorflow as tf

# 通常情况下,tf.Example中可以使用以下几种格式:
# tf.train.BytesList: 可以使用的类型包括 string和byte
# tf.train.FloatList: 可以使用的类型包括 float和double
# tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64

def _bytes_feature(value):
    """Returns a bytes_list from a string/byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Return a float_list form a float/double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Return a int64_list from a bool/enum/int/uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# tf.train.BytesList
print(_bytes_feature(b'test_string'))
print(_bytes_feature('test_string'.encode('utf8')))

# tf.train.FloatList
print(_float_feature(np.exp(1)))

# tf.train.Int64List
print(_int64_feature(True))
print(_int64_feature(1))

# tfrecord制作方法

def serialize_example(feature0, feature1, feature2, feature3):
    """
    创建tf.Example
    """

    # 转换成相应类型
    feature = {
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    }
    # 使用tf.train.Example来创建
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    # SerializeToString方法转换为二进制字符串
    return example_proto.SerializeToString()

# 数据量
n_observations = int(1e4)

# Boolean feature
feature0 = np.random.choice([False, True], n_observations)

# Integer feature
feature1 = np.random.randint(0, 5, n_observations)

# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature
feature3 = np.random.randn(n_observations)

filename = 'tfrecord-1'

with tf.io.TFRecordWriter(filename) as writer:
    for i in range(n_observations):
        example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
        writer.write(example)

# 加载tfrecord文件
filenames = [filename]

# 读取
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset

# 图像数据处理实例

import os
import glob
from datetime import datetime

import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

image_path = '../img/'
images = glob.glob(image_path + '*.jpg')

for fname in images:
    image = mpimg.imread(fname)
    f, (ax1) = plt.subplots(1, 1, figsize=(8, 8))
    f.subplots_adjust(hspace=.2, wspace=.05)

    ax1.imshow(image)
    ax1.set_title('Image', fontsize=20)

image_labels = {
    'dog': 0,
    'kangaroo': 1,
}

# 读数据,binary格式
image_string = open('./img/dog.jpg', 'rb').read()
label = image_labels['dog']


def _bytes_feature(value):
    """Returns a bytes_list from a string/byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Return a float_list form a float/double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Return a int64_list from a bool/enum/int/uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 创建图像数据的Example
def image_example(image_string, label):
    image_shape = tf.image.decode_jpeg(image_string).shape

    feature = {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image_string),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))

#打印部分信息
image_example_proto = image_example(image_string, label)

for line in str(image_example_proto).split('\n')[:15]:
    print(line)
print('...')

# 制作 `images.tfrecords`.

image_path = './img/'
images = glob.glob(image_path + '*.jpg')
record_file = 'images.tfrecord'
counter = 0

with tf.io.TFRecordWriter(record_file) as writer:
    for fname in images:
        with open(fname, 'rb') as f:
            image_string = f.read()
            label = image_labels[os.path.basename(fname).replace('.jpg', '')]

            # `tf.Example`
            tf_example = image_example(image_string, label)

            # 将`tf.example` 写入 TFRecord
            writer.write(tf_example.SerializeToString())

            counter += 1
            print('Processed {:d} of {:d} images.'.format(
                counter, len(images)))

print(' Wrote {} images to {}'.format(counter, record_file))

# 加载制作好的TFRecord

raw_train_dataset = tf.data.TFRecordDataset('images.tfrecord')
raw_train_dataset

# example数据都进行了序列化,还需要解析以下之前写入的序列化string
# tf.io.parse_single_example(example_proto, feature_description)函数可以解析单条example

# 解析的格式需要跟之前创建example时一致
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

# 现在看起来仅仅完成了一个样本的解析,实际数据不可能一个个来写吧,可以定义一个映射规则map函数

def parse_tf_example(example_proto):
    # 解析出来
    parsed_example = tf.io.parse_single_example(example_proto, image_feature_description)

    # 预处理
    x_train = tf.image.decode_jpeg(parsed_example['image_raw'], channels=3)
    x_train = tf.image.resize(x_train, (416, 416))
    x_train /= 255.

    lebel = parsed_example['label']
    y_train = lebel

    return x_train, y_train


train_dataset = raw_train_dataset.map(parse_tf_example)
train_dataset

# 制作训练集

num_epochs = 10

train_ds = train_dataset.shuffle(buffer_size=10000).batch(2).repeat(num_epochs)
train_ds

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(2, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

model.fit(train_ds, epochs=num_epochs)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值