Tensorflow读取数据(一)

数据算法是深度学习最重要的两大块。而更基础的首先是要熟练掌握一个框架来支撑算法的执行。
我个人使用最多的是tensorflow平台。就从最基础的数据输入开始记录吧。

AI算法基本流程

个人总结的AI项目基础流程(除开更复杂的工程化工作)
(1)数据预处理:get每个迭代的输入和标签。图像,音频,文本对数据处理方式又各有不同;不同的需求对标签的格式也不相同。
(2)算法建模:设计网络模型,输入:训练数据;输出:预测值
(3)优化参数:通过输出和真实label设计loss,还需要设计一个优化算法,让网络参数去学习得到最优解
(4)迭代训练:不断更新数据,在大数据上优化参数
(5)保存网络参数以及设计评价指标
以上步骤还只是算法部分,而且每个模块都有很可以展开出很多内容,其他更多工程上模块就不提了~

数据模块

今天先从数据模块下手。在训练过程中,我们对需求就是要不断的从所有数据中取一个batch数据输入到模型中。如果是python,那比较简单,伪代码如下:

#随机从datas里面抽取batch_size个数据
def get_batch(batch_size,datas):
batch_datas = []
datas.shuffle()
for i in range(batch_size):
	batch_datas.append(datas[i])
return batch_datas

但是在tensorflow框架中,我们就要利用它的优势来进行数据的读取。今天先介绍通过tf.Coordinatortf.QueueRunner来利用多线程管理数据。
tf.QueueRunner()就是负责开启线程以及线程队列
tf.train.Coordinator()就是创建一个线程管理器,管理我们开启的线程

准备数据

我们先准备两类图片数据,结构如下
在这里插入图片描述
为了方便,我们建立数据集文件夹Images,里面两类图片数据1,2。
然后我们生成一个文件列表,代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2019-09-21 22:35
# @Author  : LanguageX
import os

root_dir = os.getcwd()
fw = open("./train.txt","w")
for root, dirs, files in os.walk(root_dir):
    for file in files:
        if file.endswith("jpg") or file.endswith("png"):
            filename = os.path.join(root, file)
            class_name = filename.split("/")[-2]
            print(class_name,filename)
            fw.write(filename+" "+class_name+"\n")

目的就是生成train.txt文本列表(格式:图片路径–类别)
在这里插入图片描述
数据准备好了~下面就可以开始实现取数据的代码了~

代码框架比较简单,添加了比较详细的注释,就直接上代码吧:

# -*- coding: utf-8 -*-
# @Time    : 2019-09-21 22:24
# @Author  : LanguageX

import tensorflow as tf
import os

class DataReader:

    def get_data_lines(self, filename):
        with open(filename) as txt_file:
            lines = txt_file.readlines()
            return lines

    def gen_datas(self, train_files):
        paths = []
        labels = []
        for line in train_files:
            line = line.replace("\n","")
            path, label = line.split(" ")
            paths.append(path)
            labels.append(label)
        return paths, labels

    def __init__(self,root_dir,train_filepath,batch_size,img_size):
         self.dir = root_dir
         self.batch_size = batch_size
         self.img_size = img_size
         #读取生成的path-label列表
         self.train_files = self.get_data_lines(train_filepath)
         #获取对应的paths和labels
         self.paths,self.labels = self.gen_datas(self.train_files)
         self.data_nums  = len(self.train_files)



    def get_batch(self, batch_size):
        self.paths = tf.cast(self.paths, tf.string)
        self.labels = tf.cast(self.labels, tf.string)
        #slice_input_producer构建了取数据队列
        input_queue = tf.train.slice_input_producer([self.paths, self.labels], num_epochs=10, shuffle=True)

        # 从文件名称队列中读取文件放入文件队列
        image_batch, label_batch= tf.train.batch(input_queue, batch_size=batch_size, num_threads=2, capacity=64,
                                                  allow_smaller_final_batch=False)

        return image_batch, label_batch



if __name__ == '__main__':
    root_dir = "../images/"
    filename = "./images/train.txt"
    batch_size = 4
    image_size = 256
    dataset = DataReader(root_dir,filename,batch_size,image_size)

    images,labels = dataset.get_batch(batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        #coord线程管理器
        coord = tf.train.Coordinator()
        #tf的线程队列
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(5):
            _imgs,_labesl = sess.run([images,labels])
            print("_imgs ", _imgs)
            print("_labes ", _labesl)
        #通知线程停止
        coord.request_stop()
        coord.join(threads)
        sess.close()

运行就可以在每个迭代获取到batch_size个数据了。基本本文获取数据的基本框架,其他任务的数据读取都可以举一反三添加业务需求了~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值