Tensorflow框架 —— 训练数据读取的几种方式

1. 概述

深度学习训练的数据往往非常巨大,如何保证高效的加载数据,也是提升训练速度的关键。本文主要介绍如下几种训练数据加载方式:

  • 迭代器(iter() and next())
  • 多线程+队列
  • TFRecord

2. 迭代器

class IterTest():
    def __init__(self, data=1):
        self.data = data

    def __iter__(self):
 		# 表明 IterTest是可迭代类
        return self
        
    def __next__(self):
    	# 迭代器的具体实现
        if self.data > 5:
            raise StopIteration
        else:
            self.data += 1
            return self.data

for epoch in range(4):
    s = IterTest(3)
    for item in s:
        print(item)
# output:
	4
	5
	6
	==
	4
	5
	6
	==
	4
	5
	6
	==
	4
	5
	6
	==

for … in …: 该循环实现两件事,第一件事是获得一个可迭代器,即调用了__iter__()函数;第二件事是循环的过程,循环调用__next__()函数。
对于 IterTest 类来说,它定义了__iter__和__next__函数,所以是一个可迭代的类,也可以说是一个可迭代的对象(Python中一切皆对象)。

参考链接:https://blog.csdn.net/liweibin1994/article/details/77374854

3. 多线程+队列

通常神经网络的训练过程中,CPU由于加载数据和数据的预处理,GPU通常用于梯度计算和参数更新。鉴于CPU和GPU处理数据上的速度差别很大。GPU可以快速处理每一批次(batch)的数据,但是CPU的速度往往是无法与GPU的速度匹配,导致单位时间内GPU处于空闲状态。为了使的GPU处于最大利用状态,考虑使用多线程+队列的方式提升CPU处理数据的时间。

基本的处理流程如下:

在这里插入图片描述

代码实现

  数据处理类,用于CPU读取数据,数据增强等操作。

import random
import numpy as np
import cv2
from PIL import Image, ImageEnhance
from cv_rotation import *


class DataLoader:
    def __init__(self, file):
        # read text file: save train name list
        self.name_list = []

        data = open(file, 'r')
        for line in data:
            line = line.strip()
            self.name_list.append(line)
        random.shuffle(self.name_list)
	
	# 读取名字列表的具体实现,即target
    def name_queue_(self, name_queue):
        count = 0
        random.shuffle(self.name_list)
        while True:
            if count >= len(self.name_list):
                count = 0
                random.shuffle(self.name_list)
                continue

            name_queue.put(self.name_list[count])
            # print(self.name_list[count])
            count = count + 1
            # if name_queue.full():
            #     print('队列满')
            #     print('count: ', count)
	
	# 数据颜色类特征增强
    def image_enhance(self, img):
        p = random.randint(1, 3)
        a1 = random.uniform(0.8, 2)
        a2 = random.uniform(0.8, 1.4)
        a3 = random.uniform(0.8, 1.7)
        a4 = random.uniform(0.8, 2.5)
        img = Image.fromarray(img)

        img = ImageEnhance.Color(img).enhance(a1) if p == 0 else img
        img = ImageEnhance.Brightness(img).enhance(a2) if p == 1 else img
        img = ImageEnhance.Contrast(img).enhance(a3) if p == 2 else img
        img = ImageEnhance.Sharpness(img).enhance(a4) if p == 3 else img
        img = np.array(img)

        return img
	
	# 图像翻转
    def flip_img(self, img):
        flipped = (np.random.random() < 0.5)

        if flipped:
            img = img[:, ::-1, :]

        return img

    @staticmethod
    def show_image(name, data):
        cv2.imshow(name, data)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
	
	# 图像随机旋转
    def pose_rotation(self, img):
        w, h, c = img.shape
        deg = random.uniform(-15.0, 15.0)
        M_rotate = affine_rotation_matrix(angle=deg)
        transform_matrix = transform_matrix_offset_center(M_rotate, x=w, y=h)

        img_result = affine_transform_cv2(img, transform_matrix)

        return img_result
	
	# 读取数据,每次存储一个batch,用于直接传入GPU训练
    def load_data(self, human_data, batch, queue, thread, name_queue):
        image = []
        label = []
        data_name = []
        thread_name = []
		
		# 无限循环读取数据
        while 1:
        	# 从名字列表队列中取数据,数据的名字
            data = name_queue.get()
            # print('data: ', data)
            d1 = data.split(' ')
			
			# 根据获取的名字,opencv读取图像,标签
			# 数据增强
            if len(d1) == 2:
                data_image = '../roc_0716/' + d1[0] + '/correction.roc_0.bmp'
                if float(d1[-1]) > 10:
                    data_label = 1
                # elif float(d1[-1]) > 12:
                #     continue
                else:
                    data_label = 0
            else:
                ss = ' '.join(d1[:-1])
                # print(ss)
                data_image = '../roc_0716/' + ss + '/correction.roc_0.bmp'

                if float(d1[-1]) > 10:
                    data_label = 1
                # elif float(d1[-1]) > 12:
                #     continue
                else:
                    data_label = 0

            img = cv2.imread(data_image)
            # human_data.show_image('ori image', img)

            # 数据增强
            img = human_data.image_enhance(img)
            # human_data.show_image('enhance', img)
            img = human_data.flip_img(img)
            # human_data.show_image('flip', img)

            img = human_data.pose_rotation(img)
            img = cv2.resize(img, (320, 480))
            # human_data.show_image('resize', img)
			
			# 归一化,数据处理的常见操作
            img = img.astype(np.float32)
            img = (img - np.mean(img, axis=(0, 1))) / (np.std(img, axis=(0, 1)) + 1e-8)
			
			# 将处理完的数据放入列表
            data_name.append(data_image)
            # thread_name.append(thread)
            image.append(img)
            label.append(data_label)
			
			# 每次读满一个Batch才会存储数据到队列
            if len(image) != batch:
                continue
			
            queue.put([data_name, thread_name, np.array(image), np.array(label)])
            # print('name+++: ', data_image)
            # print('thread++: ', thread)

            image = []
            label = []
            data_name = []
            thread_name = []

  下面的代码片段为多线程的初始化,以及分配队列长度等。线程具体分配多少?这个需要参考电脑的线程的数量,不能占用全部的线程。队列的长度影响数据的预存储数量,也间接决定GPU的效率。所以,队列的长度除了考虑电脑的内存大小,还要考虑GPU的使用情况。那么,队列长度和线程数量分配的标准是保证GPU满负荷状态训练。

def main():
    tf.set_random_seed(-1)

    # ****************************************************************** #
    #                   1. Python多线程数据读取与数据增强                   #
    # ****************************************************************** #
    train_file = "./data/train.txt"  # 保存训练集的名字列表
    human_data_train = DataLoader(train_file)
    print("num of train data: ", len(human_data_train.name_list))

    # 单线程读取,存入队列,读取训练集名字
    train_name_queue = Queue(cfg.Train.Train_Num)  # len(human_data_train.name_list)
    name_process = Process(target=human_data_train.name_queue_, args=(train_name_queue, ))
    name_process.start()

    # # create queue and read train data
    cache_train_data = 300
    train_thread_num = 3
	
	# 初始化训练集存储队列
    q = Queue(cache_train_data)
    for thread in range(train_thread_num):
    	# target: 队列读取数据的方法或者实现
    	# args: 方法或者实现的参数
        p_train = Process(target=human_data_train.load_data,
                          args=(human_data_train, cfg.Train.Batch_Size, q, thread, train_name_queue))
        p_train.start()

    # load valid data
    valid_file = "./data/valid.txt"  # 保存验证集的名字列表
    human_data_valid = DataLoader(valid_file)
    print("num of valid data: ", len(human_data_valid.name_list))

    # 单线程读取数据并存入队列,读取验证集名字
    valid_name_queue = Queue(cfg.Train.Valid_Num)  # len(human_data_valid.name_list)
    valid_name_process = Process(target=human_data_valid.name_queue_, args=(valid_name_queue,))
    valid_name_process.start()

    # create queue and read valid data
    cache_valid_data = 50
    valid_thread_num = 1
	
	# 初始化存储验证集数据的队列
    valid_queue = Queue(cache_valid_data)
    for thread in range(valid_thread_num):
        p_valid = Process(target=human_data_train.load_data,
                          args=(human_data_valid, cfg.Train.Batch_Size, valid_queue, thread, valid_name_queue))
        p_valid.start()
	
	# 模型加载
	model = Model()
	... ...
	
	# 开始训练
	with tf.Session() as sess:
		# train process
		 _, _, image, label = q.get()
		sess.run([train_op], feed_dict={input:image, label_c:label})
		
		# valid process
		_, _, valid_image, valid_label = valid_queue.get()
		sess.run([train_op], feed_dict={valid_image, label_c:valid_label })
		... ...
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值