TensorFlow 组合训练数据(batching)

在之前的文章中我们提到了TensorFlow TensorFlow 队列与多线程的应用以及TensorFlow TFRecord数据集的生成与显示,通过这些操作我们可以得到自己的TFRecord文件,并从其中解析出单个的Image和Label作为训练数据提供给网络模型使用,而在实际的网络训练过程中,往往不是使用单个数据提供给模型训练,而是使用一个数据集(mini-batch),mini-batch中的数据个数称为batch-size。mini-batch的思想能够有效的提高模型预测的准确率。大部分的内容和之前的操作是相同的,数据队列中存放的还是单个的数据和标签,只是在最后的部分将出队的数据组合成为batch使用,下面给出从原始数据到batch的整个流程:
这里写图片描述
可以看到,截止到生成单个数据队列操作,和之前并没有什么区别,关键之处在于最后batch的组合,一般来说单个数据队列的长度(capacity)和batch_size有关:
capacity = min_dequeue+3*batch_size
我是这样理解第二个队列的:入队的数据就是解析出来的单个的数据,而出队的数据组合成了batch,一般来说入队数据和出队数组应该是相同的,但是在第二个队列中不是这样。

那么在TensorFlow中如何实现数据的组合呢,其实就是一个函数:
tf.train.batch
或者
tf.train.shuffle_batch
这两个函数都会生成一个队列,入队的数据是单个的Image和Label,而出队的是一个batch,也已称之为一个样例(example)。他们唯一的区别是是否将数据顺序打乱。

本文以tf.train.batch为例,定义如下:

def batch(
tensors, //张量
batch_size, //个数
num_threads=1, //线程数
capacity=32,//队列长度
enqueue_many=False, 
shapes=None, 
dynamic_pad=False,
allow_smaller_final_batch=False, 
shared_name=None, 
name=None):

下面写一个代码测试一下,工程目录下有一个TFRecord数据集文件,该代码主要做以下工作,从TFRecord中读取单个数据,每四个数据组成一个batch,一共生成10个batch,将40张图片写入指定路径下,命名规则为batch?size?Label?,batch和size决定了是第几个组合中的第几个图,label决定数据的标签。

import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np

#路径
swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)

#组合batch
batch_size = 4
mini_after_dequeue = 100
capacity = mini_after_dequeue+3*batch_size

example_batch,label_batch = tf.train.batch([image,label],batch_size = batch_size,capacity=capacity)

with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(sess = sess,coord=coord)
    for i in range(10):#10个batch
        example, l = sess.run([example_batch,label_batch])#取出一个batch
        for j in range(batch_size):#每个batch内4张图
            sigle_image = Image.fromarray(example[j], 'RGB')
            sigle_label = l[j]
            sigle_image.save(swd+'batch_'+str(i)+'_'+'size'+str(j)+'_'+'Label_'+str(sigle_label)+'.jpg')#存下图片
            print(example, l)
        
    coord.request_stop()
    coord.join(threads)

这里写图片描述

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值