https://blog.csdn.net/weixin_30235225/article/details/95432515

1. Introduction

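To split data into batches, you can use the methods in tf.train or in tf.data.Dataset. The examples below use the TensorFlow 1.x graph-mode API (tf.Session).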

1.1 tf.train

tf.train.slice_input_producer(tensor_list,shuffle=True,seed=None,capacity=32)

tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,allow_smaller_final_batch=False)

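Parameters:

shuffle: when True, the slices are produced in random order

allow_smaller_final_batch: when True, a final batch with fewer than batch_size elements is returned once the queue is closed; when False, those pending elements are discarded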
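The queue behaviour described above can be imitated in plain Python to see what tf.train.batch receives. This is only a sketch, assuming shuffle=False and the default num_epochs=None, so the producer cycles through the data forever; cycling_producer and take_batches are hypothetical helpers, not TensorFlow functions.

```python
import itertools

def cycling_producer(labels):
    # Hypothetical stand-in for slice_input_producer(shuffle=False, num_epochs=None):
    # yields the slices in order and starts over forever
    return itertools.cycle(labels)

def take_batches(producer, batch_size, n_batches):
    # Hypothetical stand-in for tf.train.batch: each batch takes the next
    # batch_size items from the queue, regardless of epoch boundaries
    return [list(itertools.islice(producer, batch_size)) for _ in range(n_batches)]

batches = take_batches(cycling_producer(range(15)), batch_size=6, n_batches=4)
for b in batches:
    print(b)
```

Because the producer never stops, the third batch wraps around into the next pass over the data, the same wrap-around that appears in the output of example 2.1.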



1.2 tf.data.Dataset

tf.data.Dataset is a class that provides the following methods:

from_tensor_slices(tensors)

batch(batch_size,drop_remainder=False)

shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

repeat(count=None)

make_one_shot_iterator() / get_next()

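Note: make_one_shot_iterator() creates an iterator over the Dataset, and get_next() on that iterator returns the tensors for the next element; each sess.run on those tensors advances the iterator.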
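As a rough mental model, the methods above behave like chained Python generators, and get_next() corresponds to calling next() on the resulting iterator. The sketch below assumes that model; batch and repeat here are hypothetical plain-Python stand-ins, not the tf.data implementation, and they ignore tensors and shuffling entirely.

```python
def batch(elements, batch_size, drop_remainder=False):
    # Hypothetical stand-in for Dataset.batch: group consecutive elements
    buf = []
    for x in elements:
        buf.append(x)
        if len(buf) == batch_size:
            yield buf
            buf = []
    # A final batch smaller than batch_size is kept unless drop_remainder=True
    if buf and not drop_remainder:
        yield buf

def repeat(make_epoch, count=None):
    # Hypothetical stand-in for Dataset.repeat: make_epoch() builds a fresh pass
    # over the data; count=None repeats forever
    n = 0
    while count is None or n < count:
        yield from make_epoch()
        n += 1

labels = list(range(15))

# Analogue of dataset.batch(6).repeat() read through get_next():
# each next() call returns one batch
it = repeat(lambda: batch(labels, 6))
first_four = [next(it) for _ in range(4)]
print(first_four)

# With drop_remainder=True the short batch [12, 13, 14] is discarded
print(list(batch(labels, 6, drop_remainder=True)))
```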

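Parameters:

tensors: can be a list, dict, tuple, etc.; the data is sliced along the first dimension, so each dataset element is one sample

drop_remainder: when False (the default), a final batch smaller than batch_size is kept; when True, it is dropped

buffer_size: the size of the buffer used for shuffling; the order is fully random only if buffer_size is at least the number of elements

count: the number of times the dataset is repeated (i.e. the number of epochs); None (or -1) repeats the data indefinitely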
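The role of buffer_size can be illustrated with a buffer-based shuffle: according to the TensorFlow documentation, shuffle fills a buffer with buffer_size elements, then samples from it at random, replacing each sampled element with the next input element. buffered_shuffle below is a hypothetical plain-Python sketch of that idea, not TensorFlow's code.

```python
import random

def buffered_shuffle(elements, buffer_size, seed=None):
    # Hypothetical sketch of a buffer-based shuffle (not TensorFlow's code)
    rng = random.Random(seed)
    buf = []
    for x in elements:
        if len(buf) < buffer_size:
            # Fill the buffer first
            buf.append(x)
            continue
        # Buffer is full: emit a random element and put the new one in its slot
        i = rng.randrange(len(buf))
        yield buf[i]
        buf[i] = x
    # Input exhausted: emit whatever is left in random order
    rng.shuffle(buf)
    yield from buf

# buffer_size=1 leaves the order unchanged; buffer_size >= 15 gives a full shuffle
print(list(buffered_shuffle(range(15), buffer_size=1)))
print(list(buffered_shuffle(range(15), buffer_size=100, seed=0)))
```

With a small buffer, an element can only move a limited distance forward, which is why a buffer at least as large as the dataset (100 vs. 15 samples in example 2.2) is needed for a uniform shuffle.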

2. Examples

2.1 Using tf.train.slice_input_producer and tf.train.batch

import tensorflow as tf
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print information about the sample data
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Number of epochs, batch size, total number of samples,
# and number of iterations needed to pass over all data once
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums / batch_size)

# Put all data into a queue with tf.train.slice_input_producer,
# then split the queued data into batches with tf.train.batch
input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape))

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue-runner threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    # Print each batch
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch + 1, iteration + 1, cu_label_batch))
    # Stop the threads and wait for them to finish
    coord.request_stop()
    coord.join(threads)

# Output:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14 0 1 2]
The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 1 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 2 iteration, current batch is [12 13 14 0 1 2]
The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]
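Because the input queue cycles through the 15 samples endlessly, the loop variable epoch does not correspond to a real pass over the data: a batch can straddle two passes, e.g. [12 13 14 0 1 2]. If tf.train.slice_input_producer is called with shuffle=True, the output order is randomized: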

images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [ 2 5 8 11 3 10]
The 1 epoch, the 2 iteration, current batch is [ 9 12 7 1 14 13]
The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13 5]
The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
The 2 epoch, the 3 iteration, current batch is [10 13 1 4 12 3]
The 3 epoch, the 1 iteration, current batch is [ 2 8 5 9 14 7]
The 3 epoch, the 2 iteration, current batch is [ 0 11 6 1 14 9]
The 3 epoch, the 3 iteration, current batch is [11 6 12 7 0 13]
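If tf.train.batch is called with allow_smaller_final_batch=True, a batch with fewer than batch_size elements can be returned. Note that such a batch can only appear after the input queue is closed, which requires passing a finite num_epochs to slice_input_producer; with num_epochs=None, as in the code above, the producer cycles indefinitely and every batch is full. The output reported for this case is: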

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
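A short final batch is only produced once the input queue is eventually closed, which happens when slice_input_producer is given a finite num_epochs. The sketch below imitates that case in plain Python, assuming shuffle=False; finite_producer and batches_until_closed are hypothetical helpers, and the point where the loop stops corresponds to TensorFlow raising tf.errors.OutOfRangeError on the next sess.run.

```python
import itertools

def finite_producer(labels, num_epochs):
    # Hypothetical stand-in for slice_input_producer(shuffle=False, num_epochs=num_epochs):
    # yields every slice num_epochs times, then the queue is closed
    return itertools.chain.from_iterable(itertools.repeat(list(labels), num_epochs))

def batches_until_closed(producer, batch_size, allow_smaller_final_batch):
    # Hypothetical stand-in for running tf.train.batch repeatedly until the queue closes
    out = []
    while True:
        chunk = list(itertools.islice(producer, batch_size))
        if len(chunk) == batch_size:
            out.append(chunk)
            continue
        # Queue closed with fewer than batch_size items left
        if chunk and allow_smaller_final_batch:
            out.append(chunk)
        # In TensorFlow the next sess.run would raise tf.errors.OutOfRangeError here
        return out

print(batches_until_closed(finite_producer(range(15), 1), 6, allow_smaller_final_batch=True))
print(batches_until_closed(finite_producer(range(15), 1), 6, allow_smaller_final_batch=False))
```

With num_epochs=2, the 30 samples divide evenly into five batches of 6, so no short batch appears at all and the third batch straddles the two passes.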
2.2 Using the tf.data.Dataset class

import tensorflow as tf
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print information about the sample data
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Number of epochs, batch size, total number of samples,
# and number of iterations needed to pass over all data once
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums / batch_size)

# Slice the data with from_tensor_slices, group it into batches with batch,
# and repeat the batched sequence indefinitely with repeat
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(batch_size).repeat()

# Read the data through a one-shot iterator and get_next
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run(next_iterator)
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch + 1, iteration + 1, cu_label_batch))

# Output:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
To shuffle the data, change the line dataset = dataset.batch(batch_size).repeat() to dataset = dataset.shuffle(100).batch(batch_size).repeat(). The result is:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [ 7 4 10 8 3 11]
The 1 epoch, the 2 iteration, current batch is [ 0 2 12 13 14 5]
The 1 epoch, the 3 iteration, current batch is [6 9 1]
The 2 epoch, the 1 iteration, current batch is [ 6 14 7 9 3 8]
The 2 epoch, the 2 iteration, current batch is [13 5 12 1 11 2]
The 2 epoch, the 3 iteration, current batch is [ 0 4 10]
The 3 epoch, the 1 iteration, current batch is [10 8 13 12 3 14]
The 3 epoch, the 2 iteration, current batch is [ 6 9 2 5 1 11]
The 3 epoch, the 3 iteration, current batch is [0 4 7]
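The position of repeat() in the chain explains why every pass in the outputs above ends with a short batch of 3 elements. The plain-Python sketch below, assuming drop_remainder=False and no shuffling, compares batch(6).repeat() with repeat().batch(6); the batch helper is hypothetical, not the tf.data implementation.

```python
import itertools

def batch(elements, batch_size):
    # Hypothetical stand-in for Dataset.batch(batch_size) with drop_remainder=False:
    # group consecutive elements and keep a shorter final group
    it = iter(elements)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield chunk

labels = list(range(15))

# batch(6).repeat(): each pass is batched separately, so every pass ends with [12, 13, 14]
one_pass = list(batch(labels, 6))
batch_then_repeat = list(itertools.islice(itertools.cycle(one_pass), 6))

# repeat().batch(6): passes are concatenated first, so a batch can straddle two passes
repeat_then_batch = list(itertools.islice(batch(itertools.cycle(labels), 6), 6))

print(batch_then_repeat)
print(repeat_then_batch)
```

The second ordering reproduces the epoch-straddling pattern of the queue-based example in 2.1, while the first keeps each pass separate at the cost of one smaller batch per pass.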
Reposted from: https://www.cnblogs.com/jfl-xx/p/9945967.html
