Notes on using TFRecord in TF 2.0
Writing ImageNet images to a .tfrecord file
Notes:
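1. tf.io.read_file() reads an image file and returns an eager tensor whose dtype is string. When writing, call the tensor's own .numpy() method, which returns the encoded image as a Python bytes object. tf.io.decode_jpeg() can decode either form (the string tensor or the bytes) into an eager tensor whose dtype is uint8;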
2. TFRecordWriter() and FixedLenFeature() have been consolidated into the tf.io module;
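3. Summary of the writing process: an Example contains one Features, and a Features contains several Feature entries. Each Feature is stored as a key-value pair and has one of three types (BytesList, FloatList, Int64List). Once the content is filled in, serialize it with Example.SerializeToString() and write it with the writer (a TFRecordWriter()).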
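Note 1 can be checked with a short sketch. To avoid depending on an external file, it first generates a tiny all-black JPEG. The file name `tiny.jpg` and the 4x4 image size are assumptions made only for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Build a tiny 4x4 RGB image and save it as a JPEG (hypothetical file name).
pixels = tf.zeros([4, 4, 3], dtype=tf.uint8)
jpeg_path = os.path.join(tempfile.mkdtemp(), "tiny.jpg")
tf.io.write_file(jpeg_path, tf.io.encode_jpeg(pixels))

raw = tf.io.read_file(jpeg_path)  # scalar eager tensor, dtype tf.string
raw_bytes = raw.numpy()           # Python bytes holding the encoded JPEG

# decode_jpeg accepts either the string tensor or the raw bytes.
decoded_from_tensor = tf.io.decode_jpeg(raw)
decoded_from_bytes = tf.io.decode_jpeg(raw_bytes)

print(raw.dtype, type(raw_bytes))
print(decoded_from_tensor.dtype, decoded_from_tensor.shape)
```

The bytes returned by .numpy() are what a BytesList feature expects, which is why the writer code below passes str_img.numpy() rather than the tensor itself.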
import os
import tensorflow as tf
"""
message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
"""
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# def float_feature(value):
#     return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def gen_tfrecord(img_dir, dst):
    # "dst" is the path of the tfrecord file.
    # channel last
    img_list = [os.path.join(root, file) for root, dirs, files in os.walk(img_dir) for file in files]
    with tf.io.TFRecordWriter(dst) as writer:
        for img in img_list:
            str_img = tf.io.read_file(img)  # encoded image, including pixels / shape / type (uint8).
            name = bytes(os.path.basename(img), encoding='utf-8')  # basename handles both '\\' (Windows) and '/' (Linux).
            feature = {
                "name": bytes_feature(name),
                "data": bytes_feature(str_img.numpy())
            }
            features = tf.train.Features(feature=feature)
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
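
Reading the records back is not covered in the notes above, but since note 2 mentions tf.io.FixedLenFeature(), here is a minimal sketch of one way it might be done. It assumes the same "name" and "data" keys as the writer. The file name `demo.tfrecord`, the helper `parse_record`, and the generated 4x4 test image are hypothetical and exist only to keep the example self-contained.

```python
import os
import tempfile

import tensorflow as tf


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# Write a single record using the same "name"/"data" keys as the writer code.
record_path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")  # hypothetical path
jpeg_bytes = tf.io.encode_jpeg(tf.zeros([4, 4, 3], dtype=tf.uint8)).numpy()
with tf.io.TFRecordWriter(record_path) as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        "name": bytes_feature(b"tiny.jpg"),
        "data": bytes_feature(jpeg_bytes),
    }))
    writer.write(example.SerializeToString())

# Each key holds exactly one bytes value, so a scalar FixedLenFeature fits.
feature_spec = {
    "name": tf.io.FixedLenFeature([], tf.string),
    "data": tf.io.FixedLenFeature([], tf.string),
}


def parse_record(serialized):
    # Hypothetical helper: parse one serialized Example and decode the image.
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    image = tf.io.decode_jpeg(parsed["data"], channels=3)
    return parsed["name"], image


dataset = tf.data.TFRecordDataset(record_path).map(parse_record)
records = [(name.numpy(), tuple(image.shape)) for name, image in dataset]
print(records)
```

Storing the encoded JPEG bytes keeps the file small and lets the image shape be recovered at decode time, so no separate height or width feature is needed in this layout.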