[Tensorflow]番外一 读取csv文件

在我们训练神经网络时,经常喜欢把训练数据存储成csv的格式,因为csv的格式基本上可以说就是一种纯文本格式,在不同的操作系统上的兼容性非常好。Tensorflow对csv文件有非常好的支持,在此就给大家介绍一种基本的读取方法。

Tensorflow读取csv文件的方法如下:

第一步 使用TextLineReader对象的read方法将csv作为文本文件逐行读取行进来。 若在创建对象时将skip_header_line参数设为1,则读取时将会略过第一行(标题行)。read方法的参数是一个文件名队列,就是说你可以让它依次读取多个csv文件。如果你只想读取一个文件,那么文件队列里只放一个文件就可以了。read方法的返回值有两个,第一个是行号,第二个是读取到的内容。

reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(file_queue)  

第二步 使用decode_csv函数解析逐行解析读取进来的内容。decode_csv的第一个参数就是需要解析的内容,record_defaults参数通过一个1维向量来描述内容的格式和数据类型。如以下例子,每行解析出6个字段的数据,每个分量的数据类型代表了输出的数据类型。1代表int型,1.0代表float型,而‘null’则是string型。

Id, Sepal_Length,Sepal_Width,Petal_Length,Petal_Width,label = tf.decode_csv(value,
record_defaults=[[1],[1.0],[1.0],[1.0],[1.0],['null']])

第三步 将读取进来的内容使用train对象的batch或shuffle_batch成批输出。batch方法之前已介绍过,这里我们来讲讲shuffle_batch。这个方法可以乱序成批输出数据,其基本用法和batch一样,但是多了几个基本参数。

capacity: 表示队列里允许的最大数据量

min_after_dequeue:表示数据从队列里出列后剩余的最小数据量。

为什么shuffle_batch方法需要增加这两个参数, 我猜想是因为乱序输出需要数据队列里存有的数据量大于输出的数据量,因此才会要定义队列的最大和最小数据量。

get_data, get_label = tf.train.shuffle_batch(
         [[Sepal_Length,Sepal_Width,Petal_Length,Petal_Width],label ], 
         batch_size = batch_size,capacity=150, min_after_dequeue=10)

这个过程基本是可以理解为读一行,解析一行,然后再成批打包输出到神经网络系统。其中,尤其需要注意的一点是TextLineReader对象的read方法返回的value值是一个tensor操作对象,而不是实际的数据。因此如果你不使用tensorflow的方法,而是直接解析value的内容必然会报错。

要让这个队列跑起来,你还需要创建一个Session对象,然后启动队列,调用Session对象的run方法来获得实际的数据。这个过程在  第一课 创建一个数据队列 里有详细的介绍,不熟悉的小伙伴可以参考那篇文章。

最后给出完整的代码,以供参考。

# -*- coding: utf-8 -*-
import tensorflow as tf

def get_batch_data(file_queue, batch_size):    
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(file_queue)   
    Id, Sepal_Length,Sepal_Width,Petal_Length,Petal_Width,label = 
    get_data, get_label = tf.train.shuffle_batch(
         [[Sepal_Length,Sepal_Width,Petal_Length,Petal_Width],label ], 
         batch_size = batch_size,capacity=150, min_after_dequeue=10)
    return get_data, get_label
   
with tf.Graph().as_default() as g: 
    #生成训练数据文件队列,此处我们只有一个训练数据文件
    train_file_queue = tf.train.string_input_producer(['iris.csv'], num_epochs=None)   
    #从训练数据文件列表中获取数据和标签
    data,label = get_batch_data(train_file_queue, TRAIN_BATCH_SIZE)
    with tf.Session() as sess:
        #创建一个Coordinator用于训练结束后关闭队列
        coord = tf.train.Coordinator()
        #启动队列
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            data_value, label_value = sess.run([data,label])           
        except tf.errors.OutOfRangeError: 
            print("Done")
        finally:
            coord.request_stop()
        coord.join(threads)   
        
    

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

编程小白的逆袭日记

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值