tf.dataset.Dataset用过多次了,一直懒得记录,今天来更新下
定义阶段
import tensorflow as tf
import numpy as np
...
def image_read(imnames,labs = None):
print(imnames)
images = tf.read_file('./reptile/combine_total_224/'+imnames)
images = tf.image.decode_jpeg(images, channels=3)
images = tf.image.resize_images(images, [224, 224])
images = tf.cast(images, tf.float32) / 255.0
images = tf.multiply(images, 2)
images = tf.subtract(images, 1.0)
return images,labs
with tf.name_scope('input'):
image_x = tf.placeholder(imagename_array.dtype,[None],name='image_x')
label_y = tf.placeholder(label_agen_array.dtype,[None,7,7,6],name='label_y')
istraing = tf.placeholder(tf.bool,name='istraing')
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices((image_x, label_y)) #对placeholder元组进行切片
dataset = dataset.map(image_read) #映射处理函数
dateset = dataset.repeat(1) #整个数据集只能轮询使用一次
dataset = dataset.batch(batch_size) #定义批次大小
iterator = dataset.make_initializable_iterator() #一种初始化方式
image_batch,label_batch = iterator.get_next() #获取下一个批次
1.因为我们每次填充placeholder都是以一个batch填充, 而tf.data.Dataset.from_tensor_slices((image_x, label_y)) #对placeholder元组进行切片;
2. 在图像识别中,我们的样本可能是文件名,通过dataset.map(image_read)可映射处理函数将文件名转化并归一化的像素;
3. dateset = dataset.repeat(n) 表示对整个数据集只能轮询使用n次,轮询完后会抛出异常;
4. dataset.make_initializable_iterator() 是dataset的一种初始化方式,还有其他初始化方式;
5. image_batch,label_batch = iterator.get_next() 表示获取下一个批次,然后我们在图中使用image_batch,label_batch,最后sess.run(),就会自动获取下一个批次;
6. 此外还有将数据随机打乱的方式,详细请查看其他资料。
使用阶段
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epochs in range(100):
print('epochs:',epochs)
sess.run(iterator.initializer,feed_dict={image_x: imagename_array,
label_y: label_agen_array})
while True:
try:
u = sess.run(optimizer,{istraing:True})
except tf.errors.OutOfRangeError:
break
- sess.run(iterator.initializer,feed_dict={image_x: imagename_array,
label_y: label_agen_array})填充样本数组上去; - 使用try和except捕获轮询完毕后的异常,然后就可以训练下一个epochs了。