TF.keras + tfrecord

TF.keras + tfrecord

在工程中,模型常常需要训练大数据,而大数据的读取通常不能一次性读取进内存中,因此需要不断从数据集中读取数据并进行处理。在大数据中,这部分的耗时相当可观,因此可以利用tfrecord进行预先处理数据,节省读取和处理的时间。

使用tfrecord有几个问题:

1.如何将图像转为tfrecord格式。

2.如何读取tfrecord文件进行训练。

3.如何读取多个tfrecord文件进行训练。

图像转为tfrecord

这里需要注意的是,由于数据过大,不能在读取tfrecord的时候打乱数据,这样打乱数据不能充分打乱所有数据,因此,在保存tfrecord的时候就应该打乱数据,建议将图像名列表打乱后在按照图像名列表顺序保存进多个tfrecord中。

1.首先打开tfrecord文件:

tf.python_io.TFRecordWriter( path, options=None)

在tensorflow1.14版本以上也可以使用:

tf.io.TFRecordWriter
tf.io.TFRecordWriter(
    path, options=None
)
Args:
  • path: The path to the TFRecords file.
  • options: (optional) String specifying compression type, TFRecordCompressionType, or TFRecordOptions object.

具有如下属性:

close
close()

Close the file.

flush
flush()

Flush the file.

write
write(
    record
)

Write a string record to the file.

tf.train.Example

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

tf.train.Features

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

tf.train.Feature
Attributes:
  • bytes_list: BytesList bytes_list
  • float_list: FloatList float_list
  • int64_list: Int64List int64_list

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

writer = tf.python_io.TFRecordWriter(os.path.join(tfrecord_save_path, ftrecordfilename))
img = Image.open(image_path, 'r')
img = img.resize((224, 224))
size = img.size

img_raw = img.tobytes()  # 将图片转化为二进制格式
example = tf.train.Example(
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'img_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
'img_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
}))
writer.write(example.SerializeToString())  # 序列化为字符串
writer.close()

读取多个tfrecord文件进行训练

通过加载tfrecord文件的文件名,传入

tf.data.TFRecordDataset
Args:
  • filenames: A tf.string tensor or tf.data.Dataset containing one or more filenames.
  • compression_type: (Optional.) A tf.string scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
  • buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value 1-100 MBs. If None, a sensible default for both local and remote file systems is used.
  • num_parallel_reads: (Optional.) A tf.int64 scalar representing the number of files to read in parallel. If greater than one, the records of files read in parallel are outputted in an interleaved order. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value greater than one to parallelize the I/O. If None, files will be read sequentially.
def get_dataset_batch(data_files):
    dataset = tf.data.TFRecordDataset(data_files)
    dataset = dataset.repeat()  # 重复数据集
    dataset = dataset.map(read_and_decode)  # 解析数据
    dataset = dataset.shuffle(buffer_size=100)  # 在缓冲区中随机打乱数据
    batch = dataset.batch(batch_size=4)  # 每10条数据为一个batch,生成一个新的Datasets
    return batch

其中调用了回调函数:read_and_decode

def read_and_decode(example_string):
    '''
    从TFrecord格式文件中读取数据
    '''
    features = tf.parse_single_example(example_string,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                           'img_width': tf.FixedLenFeature([], tf.int64),
                                           'img_height': tf.FixedLenFeature([], tf.int64)
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int64)
    label = tf.one_hot(label, 2)
    return img, label
最后进行训练
def train(model, batch):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    model.fit(batch, epochs=1, steps_per_epoch=10)
    return model
def get_model():
    model = tf.keras.applications.MobileNetV2(include_top=False, weights=None)
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    x = model(inputs)  # 此处x为MobileNetV2模型去处顶层时输出的特征相应图。
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(2, activation='softmax',
                                    use_bias=True, name='Logits')(x)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model

github代码: https://github.com/18150167970/TFrecord_tf_keras_demo

if useful:
	start work

have fun(笑)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值