tfrecord与图片格式互转

本文介绍了如何将jpg图片通过txt文件索引转换为TFRecord格式,包括读取txt文件、构建TFRecord样本和解码tfrecord,以便于模型训练和使用PyTorch DataLoader。同时,还展示了如何将解码后的图片组织到对应标签的文件夹结构中以支持ImageFolder。
摘要由CSDN通过智能技术生成

tfrecord与image互转

image2tfrecord

将.jpg文件转换成tfrecord文件

jpg文件准备:

1、一个保存图片名字的txt

2、图像命名0572721/002.jpg,"/"之前是标签

"""
train_set.txt as    {0572721/002.jpg
                    3299616/053.jpg
                    0136797/307.jpg
                    0005109/289.jpg
                    4203692/019.jpg
                    ...}
"""
def main():
	dataset = parse_txt(train_txt) #从txt中得到图像名字
	labels = [x.strip().split('/')[0] for x in dataset] #图像"/"之前是标签
	labels_set = set(labels) 
	classes_table = {x:y for y,x in enumerate(labels_set)} #存储是label1:0,label2:1,...
	writer = tf.io.TFRecordWriter(TrainTFPath) #存储到TrainTFPath,Path精确到了.tfrecord
	print("saving train dataset")

	for img_path in tqdm.tqdm(dataset):
    	tf_example = build_example(img_path,classes_table) #根据自己设置的格式转成tfrecord
    	writer.write(tf_example.SerializeToString()) #写入tfrecord文件
	writer.close()

设置tfrecord格式

def build_example(img_path,classes_table):
    image_path = os.path.join(SetPath,img_path) #只精确到了jpg上一个文件夹
    b_image = open(image_path,'rb').read()
    label = classes_table[img_path.split("/")[0]] #image_path[0]就是class_label中的label

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b_image])),
        'image/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    })) #设置tfrecord格式imagecode+label
    return example

最后得到tfrecord就是imagecode+label
使用tfrecord数据集时,这个dataloader就能直接送入训练模型

train_datasets = dataloader.load_tfrecord_dataset("./data/train.tfrecord")
def load_tfrecord_dataset(tf_path): #利用这个函数可以直接将tfrecord加载到dataset中
    datasets = tf.data.TFRecordDataset(tf_path) #读取.tfrecord文件
    datasets = datasets.repeat() 
    #重复率dataset.repeat(num),num为空表示无限重复下去,不设置则表示只重复一次
    
    return datasets.map(lambda x: parse_tfrecord(x)) #parse_tfrecord(x)是自己解码函数,这样读出来的数据才能用

tfrecord2image

#解码tfrecod格式
IMAGE_FEATURE_MAP = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/label': tf.io.FixedLenFeature([], tf.int64),
}

将tfrecord转成(image,label),label格式

def parse_tfrecord(single_record):
    x = tf.io.parse_single_example(single_record, IMAGE_FEATURE_MAP)#每次从tfrecord中取一条数据,并且是按tfrecord方式取
    
    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3) #解码成jpeg,返回tensor
    x_train = x_train/255 #编码成16进制
    y_train = tf.stack([tf.cast(x['image/label'], 'int32')])

    return (x_train, y_train), y_train

如果我想进一步把解码出来的图片复原,将同一个label做一个文件夹,将jpg文件放到这个文件夹下,这样pytorch的ImageFloder就可以直接用这个文件;

#解码成jpeg,返回tensor[height,width,channels]
tf.image.decode_jpeg(x['image/encoded'], channels=3) 

现在我们用的代码是tensorflow v1,因为v2代码会出现如下的错误

Tensor.graph is meaningless when eager execution is enabled
#2.0为了更安全,有一个默认的eager execution is enable by default
#tensorflow的Eager Execution是一种命令式变成环境,可立即评估运算,无需构建计算图

首先根据我们编码tfrecord方式解码

IMAGE_FEATURE_MAP = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/label': tf.io.FixedLenFeature([], tf.int64),
}

def parse_tfrecord(single_record):
    x = tf.io.parse_single_example(single_record, IMAGE_FEATURE_MAP)
    image=tf.image.decode_jpeg(x['image/encoded'], channels=3) #image格式[height,width,channels]
    label=tf.cast(x['image/label'], tf.int64) #label格式int
    return label, image

然后我们读入tfrecord文件,将其保存为image

def tfrecord2image(path_res):
    data_root = os.path.abspath(os.path.join(os.getcwd()))
    # print('tfrecords_files to be transformed:', path_res)
    data = tf.data.TFRecordDataset(path_res) #将tfrecord读入到TFRecordDataset中

    data = data.map(parse_tfrecord) #设置读tfrecord方式

    iterator = data.make_one_shot_iterator() #利用迭代器一条一条读tfrecord数据
    
    labels, images = iterator.get_next()#读取下一条记录

    with tf.Session() as sess: #多线程
        # start multi-threads
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        i=0
        while not coord.should_stop():
            try:
                label, img = sess.run([labels, images]) #这是我们解码得到的label,image
            except tf.errors.OutOfRangeError:
                print("Turn to next folder.")
                break
        
            image_path = os.path.join(data_root, "train", str(label)) #得到图片存储路径
            floder = os.path.exists(image_path)
            if not floder:
                os.mkdir(image_path) #新建以label为名的文件夹
           
            img=tf.image.encode_jpeg(img,format="rgb") #将[height,width,channels]转成type为string的向量
            i+=1 #为了改变图片存储名字
            with tf.gfile.GFile(image_path + '/' + '{}.jpeg'.format(i), 'wb') as f:
                f.write(img.eval()) #将image写入该文件

        print("----------transform is finished.----------")
        coord.request_stop()
        coord.join(threads)

tf部分函数

中间涉及的部分函数

tf.TFRecordReader.read(queue,name=None)
#返回读取器生成的下一个记录(key,value)对,key、value都是字符串向量
tf.cast(x, dtype, name=None)
#将x转换成dtype
tf.io.decode_raw(input_bytes, out_type,  little_endian=True, fixed_length=None, name=None)
#将输入翻译成一系列bytes,然后转化成out_type,out_type:tf.half, tf.float32, tf.float64, tf.int32, tf.uint16, tf.uint8, tf.int16, tf.int8, tf.int64
tf.io.FixedLenFeature(shape, dtype, default_value=None)
#用于解析固定长度输入要素的配置,shape是输入数据形状,dtype是输入数据类型
#default_value:当输入缺失时,需要这个来填补否则报错

tf.stack(values, axis=0, name='stack')
#将秩为R的向量列表堆叠成一个秩为(R+1)的向量,沿axis维度打包,name是此操作的名称
#axis=0,(A,B,C)变为(N,A,B,C)
#axis=1,(A,B,C)变为(A,N,B,C)
#axis超过范围会报错

#将y转化为类似one-hot编码
tf.keras.utils.to_categorical(y, num_class, dtype)

#将三维向量输出为string type向量
tf.iamge.encode_jpeg(
    image, #3维向量[height,width,channels]
    format=None, #覆盖编码输出的颜色格式,"rgb"channels=3/""/"grayscale"channels=1
    quality=None,
    progressive=None,
    optimize_size=None,
    chroma_downsampling=None,
    density_unit=None,
    x_density=None,
    y_density=None,
    xmp_metadata=None,
    name=None
)

tf.data

tf.data.TFRecordDataset支持tf.data能以序列化的方式呈现,专门读取tfrecord格式数据集;

#读取tfrecord文件,filenames变量可以是a string, a list of strings, or a tf.Tensor of strings.
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])

raw_example = next(iter(dataset)) #一条一条读取tfrecord数据
parsed = tf.train.Example.FromString(raw_example.numpy()) #序列化例子

parsed.features.feature['image/text']#对tfrecord数据进行解码

#以func方式解码dataset一条数据
dataset.map(func)

#利用迭代器读数据
iterator=dataset.make_one_shot_iterator()
next_element=iterator.get_next()

tf eager execution

tf v2默认开启eager execution。

eager模式是命令式编程,写好程序之后不需要编译就可以直接运行;

静态图模式就类似于c/c++的声明式变成,写好程序之后需要先编译,再运行;

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小橘AI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值