第一篇博客,用来整理之前写论文做实验遇到的小问题和解决方法,本文环境为tensorflow-gpu 2.5.0。
使用tfrecords原因
由于实验中使用CNN网络,图像画幅为1280*1024较大,为了提高网络模型的训练速度,不得以将数据集做成tfrecords的形式。本文主要介绍制作自己的tfrecords并在模型中作为数据使用。
制作tfrecords
代码如下:
def create_tfrecords():
record_file_name = '../tfrecords/0.4_train{}.tfrecords'.format(length)#tfrecords文件名
writer = tf.compat.v1.python_io.TFRecordWriter(record_file_name)#创建一个writer对象,将后续一个一个写好的feature放入writer
for rate in rateList:
imgname = '1.raw'
label = 1
psnr = 30
image_raw = open(imgname, mode='rb')#imgname是图像名/地址
image_bytes = image_raw.read(1310720);#图像大小为1280*1024字节image_bytes是一维数组
feature = { # 建立 tf.train.Feature 字典
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])), # 图片是一个 Bytes 对象
'rate': tf.train.Feature(float_list=tf.train.FloatList(value=[rate])),
'label': tf.train.Feature(float_list=tf.train.FloatList(value=[label])), # 标签是一个float 对象
'psnr': tf.train.Feature(float_list=tf.train.FloatList(value=[psnr]))
}#设计feature的字典格式
example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通过字典建立 Example,这个example可添加入writer
writer.write(example.SerializeToString())#将example序列化,放入writer
writer.close()#关闭writer
测试,验证,训练对应的tfrecords文件创建好了之后,就可以在模型中作为输入使用了。
读取tfrecords
读取tfrecords时,需要将数据一条一条取出来,然后将第一条数据解析为对应‘image','rate','label','psnr'标签的数据,类似example['label'],便可。
def parse_tf_img(example_proto):#解析器,将tfrecords中的一条解析为一个example
image_feature_description = {
'label': tf.io.FixedLenFeature([], tf.float32),
'rate' : tf.io.FixedLenFeature([], tf.float32),
'image': tf.io.FixedLenFeature([], tf.string),
}#由于实验中暂时用不到'psnr'数据,所以不需要把它解析出来,这样解析的画example中就只包含‘image','rate','label'.
# 解析出来
parsed_example = tf.io.parse_single_example(example_proto, image_feature_description)
y = parsed_example['label']
image = parsed_example['image']#image是tf.string格式,需要将其解码为bytes格式
image = tf.compat.v1.decode_raw(image, tf.uint8)#将image解码为bytes,uint8类型,类似数组
image = tf.reshape(image, [1280, 1024, 1]) #将一维数组转化为1280,1024,1的矩阵
image = tf.cast((image-tf.reduce_mean(image)) / (tf.reduce_max(image)-tf.reduce_min(image)), tf.float32)#将图像归一化。
y = tf.cast(parsed_example['label'], tf.float32)
return image,y #image和y都属于tf.tensor
#调用解析函数,读取tfrecords。
def read_tfrecords():
a = time.time()#记录时间
tffile = 'train30000.tfrecords'
raw_train_dataset = tf.data.TFRecordDataset(tffile)#将tfrecords的一条条数据读取出来
train_dataset = raw_train_dataset.map(parse_tf_img)#将tfrecords的一条条数据解析为example['image'],example['label'],example['rate'],这是一个迭代器,在真正需要使用下一条数据的时候才处理解析。
for x,y in train_dataset:#x对应image,y对应label。
print(type(x),type(y),x,y)
b = time.time()#记录时间
print("%.4f" % (b - a))
由于训练数据是图像的原因,所以必须要使用tf.compat.v1.decode_raw(image, tf.uint8)#将image解码为bytes,uint8类型的(0-255)区间。decode_raw的作用是将string转为bytes。
使用tfrecords数据训练模型
def train:
tffile = 'train30000.tfrecords'
val_tffile = 'val30000.tfrecords'
raw_train_dataset = tf.data.TFRecordDataset(tffile)
train_dataset = raw_train_dataset.map(parse_tf_img)
train_dataset = train_dataset.shuffle(buffer_size=10) # 在缓冲区中随机打乱数据
train_batch = train_dataset.batch(batch_size=64) #数据分为batch大小为64的批训练。
#验证集
raw_val_dataset = tf.data.TFRecordDataset(val_tffile)
val_dataset = raw_train_dataset.map(parse_tf_img)
val_dataset = val_dataset.shuffle(buffer_size=10) # 在缓冲区中随机打乱数据
val_batch = val_dataset.batch(batch_size=64) #数据分为batch大小为64的批训练。
model=MODEL;#model是一个简单的CNN网络.使用kears构建
history = model.fit(train_batch,
validation_data=val_batch,
epochs=40)#将train_batch作为训练集,将val_batch作为验证集,训练40轮。
这样就可以直接使用自己创建的tfrecords数据集了。