在进行深度学习开发之前,我们都必须面对的是数据加载问题。如何加载我们自己的数据,是我们不得不面对的一个问题,本篇以数据加载作为我们tensorflow实战的开始,教你手把手实现自己的模型训练。
目录
一、tensorflow常见的数据集格式
- 内存数据:该类数据集是通过直接读取数据,并通过注入机制来进行数据的加载,通常来说如果数据较多时会造成数据加载过慢,而且非常消耗内存。所以一般建议少量数据集时使用该方法;
- TFRecord数据:这种数据是通过队列管道的方式来加载数据,通常我们会将数据先制作成TFRecord格式然后在进行加载,这种数据加载模式非常适合有大量训练的数据。所以一般建议如果数据集较多时可以考虑这在方法;
- Dataset数据:特别强调这是下x1.4版本后的新特性,也是官方比较推荐的一种加载方法,他通过性能更高的输入管道进行加载数据,在后面部分我会着重介绍他的方法。这里,一般建议使用这种方法加载数据(tfrecord看起来太麻烦了,不知道大家有没有这种感受...);
- tf.keras接口数据:只支持keras语法的数据,这里不详细说明了。
看到这里大家会有一定的想法了,对于前两种方法,由于其局限性和一定的阅读困难,最主要是tfrecord写起来太麻烦了。
二、内存数据
在读取图片的过程中,如果图像数据集较小,则可以直接全部读取,如果数据集较多,则可能消耗大量的内存。这时候可以考虑边读边取,但如果频繁的进行读取操作可能会影响性能。这里可以采用队列的方式进行,即使用两个线程并发:一个线程用于取数据进行训练;一个线程用于读数据到内存中。
2.1、数据集说明
下载链接:https://download.pytorch.org/tutorial/hymenoptera_data.zip
这是取自于imageNet的非常小的子集。其训练集和验证集的数目见下表:
类别 | 训练集 | 验证集 |
---|---|---|
蜜蜂 | 121 | 83 |
蚂蚁 | 124 | 7 |
数据集的结构如下:
2.2、生成样本数据
2.2.1、加载文件路径和标签
读取文件夹下的图片路径和对应的标签,并存储到lfilenames 和labelsnames的list中,并对其进行shuffle操作。
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
import numpy as np
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
def load_sample(sample_dir,shuffleflag = True):
'''递归读取文件。只支持一级。返回文件名、数值标签、数值对应的标签名'''
print ('loading sample dataset..')
lfilenames = []
labelsnames = []
for (dirpath, dirnames, filenames) in os.walk(sample_dir):#递归遍历文件夹
for filename in filenames: #遍历所有文件名
#print(dirnames)
filename_path = os.sep.join([dirpath, filename])
lfilenames.append(filename_path) #添加文件名
labelsnames.append( dirpath.split('\\')[-1] )#添加文件名对应的标签
lab= list(sorted(set(labelsnames))) #生成标签名称列表
labdict=dict( zip( lab ,list(range(len(lab))) )) #生成字典
labels = [labdict[i] for i in labelsnames]
if shuffleflag == True:
return shuffle(np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)
else:
return (np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)
2.2.1、队列中批次读取数据
读取批次数据的具体步骤:
- 使用tf.train.slice_input_producer函数生成队列;
- 加载数据并进行预处理;
- 使用tf.train.batch将预处理后的数据变成批次数据。
def get_batches(image,label,input_w,input_h,channels,batch_size):
queue = tf.train.slice_input_producer([image,label]) #使用tf.train.slice_input_producer实现一个输入的队列
label = queue[1] #从输入队列里读取标签
image_c = tf.read_file(queue[0]) #从输入队列里读取image路径
image = tf.image.decode_bmp(image_c,channels) #按照路径读取图片
image = tf.image.resize_image_with_crop_or_pad(image,input_w,input_h) #修改图片大小
image = tf.image.per_image_standardization(image) #图像标准化处理,(x - mean) / adjusted_stddev
image_batch,label_batch = tf.train.batch([image,label],#调用tf.train.batch函数生成批次数据
batch_size = batch_size,
num_threads = 64)
images_batch = tf.cast(image_batch,tf.float32) #将数据类型转换为float32
labels_batch = tf.reshape(label_batch,[batch_size])#修改标签的形状shape
return images_batch,labels_batch
2.2.3、在Session中使用数据集
通过在静态图Session中启动一个带有协调器的队列线程来获取数据,具体如下:
(image,label),labelsnames = load_sample(data_dir) #载入文件名称与标签
batch_size = 16
image_batches,label_batches = get_batches(image,label,28,28,1,batch_size)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init) #初始化
coord = tf.train.Coordinator() #开启列队
threads = tf.train.start_queue_runners(sess = sess,coord = coord)
try:
for step in np.arange(10):
if coord.should_stop():
break
images,label = sess.run([image_batches,label_batches]) #注入数据
except tf.errors.OutOfRangeError:
print("Done!!!")
finally:
coord.request_stop()
coord.join(threads) #关闭列队
完整代码链接:https://github.com/kingqiuol/learning_tensorflow/blob/master/data/load_imagedata.py
执行上述完整代码后的结果如下:
三、TFRecord数据
TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
3.1、生成TFRecord数据集
这里我们使用第二节中的数据集进行演示,同时使用load_sample函数来加载数据集的路径和标签数据,在此可以参考上一节。在使用TFRecord数据集之前我们需要将我们的数据制作成TFRecord的格式方便后续的训练。具体实现流程如下:
- 按照load_sample读取的数据进行读取图片;
- 将读取的图片和标签进行打包组合在一起;
- 使用TFRecordWriter对象的write方法将图片和标签写入文件中。
def makeTFRec(filenames,labels): #定义函数生成TFRecord
writer= tf.python_io.TFRecordWriter("mydata.tfrecords") #通过tf.python_io.TFRecordWriter 写入到TFRecords文件
for i in tqdm( range(0,len(labels) ) ):
img=Image.open(filenames[i])
img = img.resize((256, 256))
img_raw=img.tobytes()#将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
#存放图片的标签label
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),
#存放具体的图片
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串
writer.close() #数据集制作完成
3.2、在队列中批量读取数据
通常在训练集中我们需要对数据进行乱序操作,并按指定的方法组合,而在测试集中,只需一次加载,不需要乱序和批次组合。具体实现如下:
def read_and_decode(filenames,flag = 'train',batch_size = 3):
#根据文件名生成一个队列
if flag == 'train':
filename_queue = tf.train.string_input_producer(filenames)#默认已经是shuffle并且循环读取
else:
filename_queue = tf.train.string_input_producer(filenames,num_epochs = 1,shuffle = False)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example, #取出包含image和label的feature对象
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [256,256,3])
#
label = tf.cast(features['label'], tf.int32)
if flag == 'train':
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 #归一化
img_batch, label_batch = tf.train.batch([image, label], #还可以使用tf.train.shuffle_batch进行乱序批次
batch_size=batch_size, capacity=20)
return img_batch, label_batch
return image, label
3.3、在Session中使用数据集
TFRecordfilenames = ["mydata.tfrecords"]
image, label =read_and_decode(TFRecordfilenames,flag='test') #以测试的方式打开数据集
#开始一个会话读取数据
with tf.Session() as sess:
sess.run(tf.local_variables_initializer()) #初始化本地变量,没有这句会报错
#启动多线程
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
try:
while True:
example, examplelab = sess.run([image,label])#在会话中取出image和label
except tf.errors.OutOfRangeError:
print('Done Test -- epoch limit reached')
finally:
coord.request_stop()
coord.join(threads)
print("stop()")
完整代码链接:https://github.com/kingqiuol/learning_tensorflow/blob/master/data/tfrecord_imagedata.py
参考链接:
四、Dataset数据集
Dataset数据集是由tf.data.Dataset接口实现的,通过该接口使得tensorflow能够更加方便、快速的处理数据集。Dataset对象能直接对其上的数据进行相关乱序、迭代等操作(越来越和pytorch有点像了)。
Dataset对象的创建:
- tf.data.Dataset.from_tensors:根据内存对象创建,且只有一个元素;
- tf.data.Dataset.from_tensors_slices:根据内存对象创建,对象可以是list、dict、set、Numpy等;
- tf.data.Dataset.from_generator:根据迭代器生成对象。
这几种方法比较类似,一般用的比较多的是第二种方法,建议着重掌握这种方法。
Dataset对象支持的操作:
1、dataset.shuffle(buffer_size,seed=None,reshuffle_each_iteration=None):将数据内部的元素顺序打乱
- buffer_size:随机打乱元素排序的大小,一般越大越混乱。
- seed:随机种子,一般不用管。
- reshuffle_each_iteration:是否每次迭代都乱序。
2、dataset.repeat(count=None):生成重复的数据,count代表重复的次数
3、dataset.map(map_func,num_parallel_cell=None):通过map_func来将数据集中的每一个元素进行转换处理
- map_func:处理函数
- num_parallel_cell:并行处理的线程个数
4、dataset.batch(batch_size,drop_remainder):将数据集的元素按照批次进行组合
- batch_size:批次大小。
- drop_remainder:是否忽略批次组合后剩余的数据
5、dataset.prefetch(buffer_size):设置从数据集中取数据时的最大缓冲区。一般推荐将buffer_size设置为tf.data.experimental.AUTOTUNE,代表系统自动调节大小
一般来讲,处理数据的合理步骤为:创建Dataset对象->乱序数据集(shuffle)->重复数据集(repeat)->数据预处理(map)->设定批次(batch)->设定缓存(prefetch)。
在训练过程中有时会出现某次数据不足的情况,一般造成这种情况的原因有:数据总数不能被batch_size整除,而在训练过程中的剩余数据也会进入训练。解决的方法主要有:第一种方法是对数据进行repeat操作,在进行batch设置;第二种方法是将batch函数中的drop_remainder参数设置为True,这样在训练过程中就会丢弃剩余数据,从而避免批次数据不足的情况。
Dataset数据集的操作步骤:
(1)创建Dataset对象;
(2)对Dataset对象进行变换操作;
(3)创建Dataset迭代器;
(4)在会话Session中取数据。
4.1、生成Dataset对象
这里我们还是使用之前的数据集进行测试,同时使用load_sample函数来加载数据集的路径和标签数据,在此可以参考第二节。这里我们使用了上述的一些函数来进行创建,具体实现如下:
def _norm_image(image,size,ch=1,flattenflag = False): #定义函数,实现归一化,并且拍平
image_decoded = image/255.0
if flattenflag==True:
image_decoded = tf.reshape(image_decoded, [size[0]*size[1]*ch])
return image_decoded
def dataset(directory,size,batchsize,random_rotated=False):#定义函数,创建数据集
""" parse dataset."""
(filenames,labels),_ =load_sample(directory,shuffleflag=False) #载入文件名称与标签
def _parseone(filename, label): #解析一个图片文件
""" Reading and handle image"""
image_string = tf.read_file(filename) #读取整个文件
image_decoded = tf.image.decode_image(image_string)
image_decoded = tf.image.resize(image_decoded, size) #变化尺寸
image_decoded = _norm_image(image_decoded,size)#归一化
image_decoded = tf.cast(image_decoded,dtype=tf.float32)
label = tf.cast( tf.reshape(label, []) ,dtype=tf.int32 )#将label 转为张量
return image_decoded, label
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))#生成Dataset对象
dataset = dataset.map(_parseone) #有图片内容的数据集
dataset = dataset.batch(batchsize) #批次划分数据集
return dataset
4.2、在Session中使用Dataset数据集
def getone(dataset):
iterator = dataset.make_one_shot_iterator() #生成一个迭代器
one_element = iterator.get_next() #从iterator里取出一个元素
return one_element
sample_dir=r"./hymenoptera_data/train"
size = [96,96]
batchsize = 10
tdataset = dataset(sample_dir,size,batchsize)
print(tdataset.output_types) #打印数据集的输出信息
print(tdataset.output_shapes)
one_element1 = getone(tdataset) #从tdataset里取出一个元素
with tf.Session() as sess: # 建立会话(session)
sess.run(tf.global_variables_initializer()) #初始化
try:
for step in np.arange(1):
image,label = sess.run(one_element1)
except tf.errors.OutOfRangeError: #捕获异常
print("Done!!!")
完整代码链接:https://github.com/kingqiuol/learning_tensorflow/blob/master/data/dataset_imagedata.py
执行上述完整代码后的结果如下:
五、总结
这里我们对tensorflow数据的加载方式有了一定的了解,同时也建议大家在以后的模型搭建过程中尽量使用Dataset这种数据加载方式,在之后的章节中我将继续讲解数据加载的进阶,讲解在实际工程中的使用方法,哈哈哈,有兴趣的话可以用继续看看我的下一篇文章。