Tensor Flow shuffle_batch 的方式读csv文件的例子

用最简单的代码展示了一个tensor flow shuffle的方式读文件

代码

#coding=utf-8                                                                                                                                                                                                                                                                 

import tensorflow as tf
import numpy as np

def readMyFileFormat(fileNameQueue):
    reader = tf.TextLineReader()
    key, value = reader.read(fileNameQueue)

   record_defaults = [[1], [1], [1]]
    col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults)
    features = tf.pack([col1, col2])
    label = col3
    return features, label

def inputPipeLine(fileNames = ["file0.csv", "file1.csv"], batchSize = 4, numEpochs = None):
    fileNameQueue = tf.train.string_input_producer(fileNames, num_epochs = numEpochs)
    example, label = readMyFileFormat(fileNameQueue)
    min_after_dequeue = 8
    capacity = min_after_dequeue + 3 * batchSize
    exampleBatch, labelBatch = tf.train.shuffle_batch([example, label], batch_size = batchSize, num_threads = 3,  capacity = capacity, min_after_dequeue = min_after_dequeue)
    return exampleBatch, labelBatch

featureBatch, labelBatch = inputPipeLine(["file0.csv", "file1.csv"], batchSize = 4)
with tf.Session() as sess:
    # Start populating the filename queue.                                                                                                                                                                                                                                    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

   # Retrieve a single instance:                                                                                                                                                                                                                                             
    try:
        #while not coord.should_stop():                                                                                                                                                                                                                                       
        while True:
            example, label = sess.run([featureBatch, labelBatch])
            print example
    except tf.errors.OutOfRangeError:
        print 'Done reading'
    finally:
        coord.request_stop()

   coord.join(threads)
    sess.close()

file0.csv 的内容

9,1,1
10,2,3
11,3,1
12,4,2

file1.csv 的内容

1,1,7
2,2,8
3,3,5
4,4,9
5,5,5
6,6,1
7,7,2
8,8,4
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值