转载自 https://blog.csdn.net/qq_36330643/article/details/77366083
关于TensorFlow读取数据,官网提供了3种方法:
1)Feeding : 在TensorFlow程序运行的每一个epoch,用python代码在线提供数据
2)Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中
3)在声明 tf.variable 变量或 numpy 数组时保存数据。受限于内存大小,适用于数据较小的情况
准备图片数据
现在用这两类图片制作TFRecords文件
制作TFRecords文件
1)TFRecords是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在TensorFlow中快速的复制、移动、读取、存储等。
注意:存储数据时,TFRecords会根据你选择输入文件的类,自动给每一类打上同样的标签。在本例中,只有0,1两类。
2)制作TFRecords文件的代码
import tensorflow as tf
from PIL import Image
cwd = 'D:\tfrecords测试\dog\\'
classes = {'husky', 'chihuahua'}
def creat_tfrecords(tfrecords_name):
writer = tf.python_io.TFRecordWriter(tfrecords_name)
for index, name in enumerate(classes):
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((128, 128))
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
调用以上函数后,会生成一个tfrecords文件,名字为函数的输入参数tfrecords_name
3)tf.train.Example 协议内存块包含了Features字段,通过feature将图片的二进制数据和label进行统一封装, 然后将example协议内存块转化为字符串, tf.python_io.TFRecordWriter 写入到TFRecords文件中。
读取TFRecords文件
1) 在制作完tfrecords文件后,将该文件读入到数据流中,代码如下:
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(20):
example, l = sess.run([image, label])
# 下面两行代码功能是 将读取到的img存储为.jpg图片
img = Image.fromarray(example, 'RGB')
img.save(cwd + str(i) + '_Label_' + str(l) + '.jpg')
print(example, l)
coord.request_stop()
coord.join(threads)
注意: example, l = sess.run([image, label]) 的功能是将tensor格式的image和label转换为array
2)调用以上函数后,会在文件夹中得到20张以一定规则命名的.jpg文件
读取一个batch的数据
1)我们训练神经网络的时候往往不是将数据一个一个的输入到网络里,而是输入一个batch的数据。将多个输入样例组织成一个batch可以提高模型的训练效率。
TensorFlow提供了 tf.train.batch 和 tf.shuffle_batch 函数来将单个的样例组织成batch的形式输出。它们唯一的区别在于是否会将数据的顺序打乱。
tf.train.batch 有三个输入参数,第一个是数据,第二个是batch_size,第三个是 capacity。
capacity表示组合样例的队列中最多可以存储的样例个数,这个队列如果太大,那么需要占用很多内存资源;如果太小,那么出队操作可能会因为没有数据而被阻碍(block),从而导致训练效率降低。一般来说这个队列的大小会和每一个batch的大小有关,一般可设置为 capacity = 1000 + 3 * batch_size
tf.train.shuffle_batch 比 tf.train.batch 多一个参数 min_after_dequeue.
min_after_dequeue 限制了出队时队列中元素的最少个数。当队列中元素太少时,随机打乱样例顺序的作用就不大了。当出队函数被调用但是队列中元素不够时,出队操作将等待更多的元素入队才会完成,如果 min_after_dequeue 参数被设定, capacity 也应该相应调整来满足性能要求。
下面代码说明读取一个batch的方法:
# 读取tfrecords文件的函数
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
'label':tf.FixedLenFeature([], tf.int64),
'img_raw':tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [200, 200, 3])
label = tf.cast(features['label'], tf.int32)
return image, label
----------
example, label = read_and_decode('train.tfrecords')
# 读取一个batch的数据和标签
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=3, capacity=1000, min_after_dequeue=500)
# 开启一个会话
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# 取3次batch_size=3的数据,并打印出标签 看取数据的情况。
for i in range(3):
cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
print(cur_label_batch)
coord.request_stop()
coord.join(threads)
将label转换为one hot形式
参考https://blog.csdn.net/a_yangfh/article/details/77911126
详细讲解参考上述博客
以手写数字识别为例,我们需要将0-9共十个数字标签转化成onehot标签。例如:数字标签“6”转化为onehot标签就是[0,0,0,0,0,0,1,0,0,0]
而我们通过tfrecords读取到的label并不是one hot 形式的,而是类似于[1, 2, 3, 4, 5]形式的,其中数字n代表第n类。
通过以下函数可以将我们读取到的label 转换成 one hot 形式:
import tensorflow as tf
# labels为需要转换的标签
# NUM_CLASSES为类别数
# batch_size为batch大小
def turn_to_onehot(labels, NUM_CLASSES, batch_size):
labels = tf.expand_dims(labels, 1) # 增加一个维度
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1) # 生成索引
concated = tf.concat([indices, labels], 1) # 作为拼接
onehot_labels = tf.sparse_to_dense(concated, tf.stack([batch_size, NUM_CLASSES]), 1.0, 0.0) # 生成one-hot编码的标签
return onehot_labels