ResNet Transfer Learning (Part 2): Data Loading (DataLoader.py)

Overview

A defining characteristic of deep learning is that it processes large amounts of data, so how well the CPU and GPU are utilized has a significant effect on training time. Dataset size also varies widely with the task: ImageNet contains images on the order of millions, COCO on the order of tens of thousands, whereas in our application data collection is difficult and we only have on the order of a thousand samples.

Code

  The code below is DataLoader.py. It covers data reading, data augmentation, and parallel loading of the training and validation sets by multiple worker processes (the code uses multiprocessing rather than threads). When several workers load data in parallel, we must make sure they do not read duplicate samples. Because a queue is first-in-first-out and each get() removes exactly one item, the solution is a shared name queue (train_name_queue): every worker pulls sample names from this single queue, so each sample is handed to exactly one worker.
  The processing flow of the code is shown in the figure below:
[Figure: processing flow of the data-loading pipeline]
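
A minimal, self-contained sketch of this shared-queue pattern is shown first, before the full file. Everything in it (produce_names, consume_names, the fake image names) is made up for illustration and is not part of the project; it only demonstrates that one producer put()s names into a multiprocessing.Queue while several consumers get() from that same queue, so each name reaches exactly one consumer.

# Illustrative sketch only (hypothetical names), not part of DataLoader.py
from multiprocessing import Process, Queue
import time


def produce_names(name_queue, names):
    # Cycle over the name list forever, like DataLoader.name_queue_ does
    while True:
        for n in names:
            name_queue.put(n)            # blocks when the queue is full


def consume_names(name_queue, worker_id, batch_size=4):
    batch = []
    while True:
        batch.append(name_queue.get())   # each get() removes one name from the shared queue
        if len(batch) == batch_size:
            print('worker %d got batch: %s' % (worker_id, batch))
            batch = []


if __name__ == '__main__':
    names = ['img_%03d.jpg' % i for i in range(32)]  # stand-in for the training name list
    name_queue = Queue(maxsize=16)

    Process(target=produce_names, args=(name_queue, names), daemon=True).start()
    for wid in range(2):                             # two workers, as in train_set_queue()
        Process(target=consume_names, args=(name_queue, wid), daemon=True).start()

    time.sleep(1)                                    # let the demo run briefly, then exit

The full DataLoader.py is listed below.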

import random
import cv2
import numpy as np
from PIL import Image, ImageEnhance
from cv_rotation import *
from config import cfg
from multiprocessing import Queue, Process


class DataLoader:
    def __init__(self, file):
        self.input = cfg.Train.Input_Size
        self.root_path = cfg.Train.Root_Path

        # Read the text file that lists the training samples into a name list
        self.name_list = []
        with open(file, 'r') as data:
            for line in data:
                line = line.strip()
                s1 = line.split('/')
                if s1[0] != '0803ply':  # skip samples from the '0803ply' subset
                    self.name_list.append(line)
        random.shuffle(self.name_list)

    def name_queue_(self, name_queue):
        """Endlessly push sample names into name_queue, reshuffling after every full pass."""
        count = 0
        random.shuffle(self.name_list)
        while True:
            if count >= len(self.name_list):
                # A full pass over the name list has been pushed; reshuffle and start again
                count = 0
                random.shuffle(self.name_list)
                continue

            name_queue.put(self.name_list[count])  # blocks when the queue is full
            count = count + 1

    def image_enhance(self, img):
        # Randomly apply one of four photometric jitters (color, brightness, contrast, sharpness)
        p = random.randint(0, 3)  # 0-3 inclusive, so every branch below is reachable
        a1 = random.uniform(0.8, 2)
        a2 = random.uniform(0.8, 1.4)
        a3 = random.uniform(0.8, 1.7)
        a4 = random.uniform(0.8, 2.5)
        img = Image.fromarray(img)

        img = ImageEnhance.Color(img).enhance(a1) if p == 0 else img
        img = ImageEnhance.Brightness(img).enhance(a2) if p == 1 else img
        img = ImageEnhance.Contrast(img).enhance(a3) if p == 2 else img
        img = ImageEnhance.Sharpness(img).enhance(a4) if p == 3 else img
        img = np.array(img)

        return img

    def flip_img(self, img):
        flipped = (np.random.random() < 0.5)

        if flipped:
            img = img[:, ::-1, :]

        return img

    @staticmethod
    def show_image(name, data):
        cv2.imshow(name, data)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def pose_rotation(self, img):
        h, w, c = img.shape  # OpenCV images are (height, width, channels)
        deg = random.uniform(-15.0, 15.0)
        M_rotate = affine_rotation_matrix(angle=deg)
        transform_matrix = transform_matrix_offset_center(M_rotate, x=w, y=h)

        img_result = affine_transform_cv2(img, transform_matrix)

        return img_result

    def load_data(self, batch_size, queue, thread, name_queue, mode):
        """
        从名字队列中逐个读取训练数据,按一个batch存储
        :param batch_size: 批次大小
        :param queue: 存储数据的队列
        :param thread: 分配的线程数
        :param name_queue: 训练数据集的名字队列
        :param mode: train or valid mode, 对应训练集和验证集上不同的处理方式
        :return: 存有数据的queue
        """
        image = []
        label = []
        data_name = []
        thread_name = []

        # Number of class-0 / class-1 samples seen so far (used by the 'train-no-balance' mode)
        sign_0 = 0
        sign_1 = 0

        while True:
            data = name_queue.get()
            d1 = data.split(' ')

            # Parse one record: everything before the last space is the image path,
            # the last field is a score that is thresholded into a binary label
            data_image = self.root_path + ' '.join(d1[:-1])
            if float(d1[-1]) >= 10:
                data_label = 1
                sign_1 = sign_1 + 1
            else:
                data_label = 0
                sign_0 = sign_0 + 1

            img = cv2.imread(data_image)
            # self.show_image('ori image', img)

            # Data augmentation (training mode only)
            if mode == 'train':
                # img = self.image_enhance(img)
                # self.show_image('enhance', img)
                img = self.flip_img(img)
                # self.show_image('flip', img)
                # img = self.pose_rotation(img)

            img = cv2.resize(img, (self.input[1], self.input[0]))
            # self.show_image('resize', img)

            # Normalization: cast to float32 and subtract fixed per-channel offsets
            img = img.astype(np.float32)
            # img = (img - np.mean(img, axis=(0, 1))) / (np.std(img, axis=(0, 1)) + 1e-8)
            img = img - np.array([0, 109.32738873, 109.89168176]).reshape(1, 3)

            if mode == 'train-no-balance':
                # 'train-no-balance' mode: keep at most 16 samples of each class
                # (sign_0 / sign_1 are never reset, so at most 32 samples are collected in total)
                if (sign_0 <= 16) and (data_label == 0):
                    data_name.append(data_image)
                    image.append(img)
                    thread_name.append(thread)
                    label.append(data_label)

                if (sign_1 <= 16) and (data_label == 1):
                    data_name.append(data_image)
                    image.append(img)
                    thread_name.append(thread)
                    label.append(data_label)
            else:
                data_name.append(data_image)
                image.append(img)
                thread_name.append(thread)
                label.append(data_label)

            if len(image) != batch_size:
                continue

            # A full batch has been collected; push it into the output queue
            queue.put([data_name, thread_name, np.array(image), np.array(label)])

            image = []
            label = []
            data_name = []
            thread_name = []


def train_set_queue():
    """
    Build the training-data pipeline. It has three parts:
        1. Read the training names: load all training sample names from the text file into a list.
        2. Fill the name queue: a dedicated process keeps pushing the names into train_name_queue.
        3. Build the training-data queue: several worker processes get names from train_name_queue
           in parallel, load the corresponding images from disk and push batches into train_queue.
    :return: queue that stores batches of training data
    """
    train_file = cfg.Train.Train_Set
    human_data_train = DataLoader(train_file)
    print("num of train data: ", len(human_data_train.name_list))

    # A single dedicated process feeds the training names into the name queue
    train_name_queue = Queue(cfg.Train.Train_Num)
    name_process = Process(target=human_data_train.name_queue_, args=(train_name_queue, ))
    name_process.start()

    # Create the data queue and start the worker processes that fill it
    cache_train_data = 200
    train_thread_num = 2

    train_queue = Queue(cache_train_data)
    for thread in range(train_thread_num):
        p_train = Process(target=human_data_train.load_data,
                          args=(cfg.Train.Batch_Size, train_queue, thread, train_name_queue, 'train'))
        p_train.start()
    return train_queue


def valid_set_queue():
    """
    Same processing flow as train_set_queue(), but for the validation set.
    :return: queue that stores batches of validation data
    """
    valid_file = cfg.Train.Valid_Set
    human_data_valid = DataLoader(valid_file)
    print("num of valid data: ", len(human_data_valid.name_list))

    # A separate process feeds the validation names into their own queue
    valid_name_queue = Queue(cfg.Train.Valid_Num)
    valid_name_process = Process(target=human_data_valid.name_queue_, args=(valid_name_queue,))
    valid_name_process.start()

    # Create the data queue and start the worker process that fills it
    cache_valid_data = 128
    valid_thread_num = 1

    valid_queue = Queue(cache_valid_data)
    for thread in range(valid_thread_num):
        p_valid = Process(target=human_data_valid.load_data,
                          args=(cfg.Train.Batch_Size, valid_queue, thread, valid_name_queue, 'valid'))
        p_valid.start()
    return valid_queue
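
For reference, here is a hedged sketch of how a training script might consume these queues. The step count, the validation interval, and the placeholder comments standing in for the actual model are assumptions, not part of the project:

# Hypothetical usage sketch: consuming the queues built by train_set_queue() / valid_set_queue()
if __name__ == '__main__':
    train_queue = train_set_queue()
    valid_queue = valid_set_queue()

    for step in range(1000):
        # Blocks until a worker process has pushed a complete batch
        names, workers, images, labels = train_queue.get()
        # images: (Batch_Size, H, W, 3) float32; labels: (Batch_Size,) with values in {0, 1}
        # ... run one training step on (images, labels) here ...

        if step % 50 == 0:
            val_names, _, val_images, val_labels = valid_queue.get()
            # ... evaluate on (val_images, val_labels) here ...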
