制作自己的tfrecords并在keras代码中使用(tensorflow相关)

第一篇博客,用来整理之前写论文做实验遇到的小问题和解决方法,本文环境为tensorflow-gpu 2.5.0。

使用tfrecords原因

由于实验中使用CNN网络,图像画幅为1280*1024较大,为了提高网络模型的训练速度,不得以将数据集做成tfrecords的形式。本文主要介绍制作自己的tfrecords并在模型中作为数据使用。

制作tfrecords

代码如下:

def create_tfrecords():
    record_file_name = '../tfrecords/0.4_train{}.tfrecords'.format(length)#tfrecords文件名
    writer = tf.compat.v1.python_io.TFRecordWriter(record_file_name)#创建一个writer对象,将后续一个一个写好的feature放入writer
    for rate in rateList:
            
       imgname = '1.raw'
       label = 1
       psnr = 30
       image_raw = open(imgname, mode='rb')#imgname是图像名/地址
       image_bytes = image_raw.read(1310720);#图像大小为1280*1024字节image_bytes是一维数组


       feature = {  # 建立 tf.train.Feature 字典
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),  # 图片是一个 Bytes 对象
                    'rate': tf.train.Feature(float_list=tf.train.FloatList(value=[rate])),
                    'label': tf.train.Feature(float_list=tf.train.FloatList(value=[label])),  # 标签是一个float 对象
                    'psnr': tf.train.Feature(float_list=tf.train.FloatList(value=[psnr]))
                }#设计feature的字典格式
       example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通过字典建立 Example,这个example可添加入writer
       writer.write(example.SerializeToString())#将example序列化,放入writer
    writer.close()#关闭writer

测试,验证,训练对应的tfrecords文件创建好了之后,就可以在模型中作为输入使用了。

读取tfrecords

读取tfrecords时,需要将数据一条一条取出来,然后将第一条数据解析为对应‘image','rate','label','psnr'标签的数据,类似example['label'],便可。

def parse_tf_img(example_proto):#解析器,将tfrecords中的一条解析为一个example
    image_feature_description = {
        'label': tf.io.FixedLenFeature([], tf.float32),
        'rate' : tf.io.FixedLenFeature([], tf.float32),
        'image': tf.io.FixedLenFeature([], tf.string),
    }#由于实验中暂时用不到'psnr'数据,所以不需要把它解析出来,这样解析的画example中就只包含‘image','rate','label'.
    # 解析出来
    parsed_example = tf.io.parse_single_example(example_proto, image_feature_description)

    y = parsed_example['label']
    image = parsed_example['image']#image是tf.string格式,需要将其解码为bytes格式
    image = tf.compat.v1.decode_raw(image, tf.uint8)#将image解码为bytes,uint8类型,类似数组
    image = tf.reshape(image, [1280, 1024, 1])    #将一维数组转化为1280,1024,1的矩阵
    image = tf.cast((image-tf.reduce_mean(image)) / (tf.reduce_max(image)-tf.reduce_min(image)), tf.float32)#将图像归一化。
    y = tf.cast(parsed_example['label'], tf.float32)

    return image,y #image和y都属于tf.tensor






#调用解析函数,读取tfrecords。
def read_tfrecords():
    

    a = time.time()#记录时间
    tffile = 'train30000.tfrecords'

    raw_train_dataset = tf.data.TFRecordDataset(tffile)#将tfrecords的一条条数据读取出来
    train_dataset = raw_train_dataset.map(parse_tf_img)#将tfrecords的一条条数据解析为example['image'],example['label'],example['rate'],这是一个迭代器,在真正需要使用下一条数据的时候才处理解析。

    for x,y in train_dataset:#x对应image,y对应label。
        print(type(x),type(y),x,y)


    b = time.time()#记录时间
    print("%.4f" % (b - a))

由于训练数据是图像的原因,所以必须要使用tf.compat.v1.decode_raw(image, tf.uint8)#将image解码为bytes,uint8类型的(0-255)区间。decode_raw的作用是将string转为bytes。

使用tfrecords数据训练模型

def train:
    tffile = 'train30000.tfrecords'
    val_tffile = 'val30000.tfrecords'

    raw_train_dataset = tf.data.TFRecordDataset(tffile)
    train_dataset = raw_train_dataset.map(parse_tf_img)

    train_dataset = train_dataset.shuffle(buffer_size=10)  # 在缓冲区中随机打乱数据
    train_batch = train_dataset.batch(batch_size=64)  #数据分为batch大小为64的批训练。

#验证集
    raw_val_dataset = tf.data.TFRecordDataset(val_tffile)
    val_dataset = raw_train_dataset.map(parse_tf_img)

    val_dataset = val_dataset.shuffle(buffer_size=10)  # 在缓冲区中随机打乱数据
    val_batch = val_dataset.batch(batch_size=64)  #数据分为batch大小为64的批训练。

    model=MODEL;#model是一个简单的CNN网络.使用kears构建

    history = model.fit(train_batch,
                    validation_data=val_batch,
                    epochs=40)#将train_batch作为训练集,将val_batch作为验证集,训练40轮。

这样就可以直接使用自己创建的tfrecords数据集了。

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值