Dataloader
当我们使用pytorch进行深度学习模型开发时,为了使用自己的数据集训练网络,往往需要构建自己的dataloader。dataloader用于构建可迭代的数据装载器,可以使用for img, labels in dataloaders
进行可迭代对象的访问,从而提取数据用于训练与验证。
pytorch 循环这个 DataLoader 对象,将img, label加载到模型中进行训练。
自定义
自定义的LoadData三个方法是缺一不可的:
1.__init__()
,主要用来定义数据的预处理
2.__getitem__
方法,返回数据的img和label
3.__len__
方法,返回数据个数
pytorch会通过本文自定义的dataloader,从ch1,ch2,ch3三个文件夹按一定比例提取img和mask中的数据作为模型输入,从而实现多通道按比例输入不同label数据。
代码
import os
import numpy as np
import random
import cv2
from torch.utils import data
import glob
class My_data(data.Dataset):
def __init__(self, opt):
self.opt = opt
self.__counter = 0
self.load_data()
def load_data(self):
root_path = self.opt.data_root
name_img_png = R'/img/*.png'
name_mask_png = R'/mask/*.png'
temp_char = 'ch1'
self._channel_1_img = glob.glob(os.path.join(root_path, temp_char+name_img_png))
self._channel_1_mask = glob.glob(os.path.join(root_path, temp_char+name_mask_png))
self.check_error(self._channel_1_img, self._channel_1_fake, self._channel_1_mask, temp_char)
temp_char = 'ch2'
self._channel_2_img = glob.glob(os.path.join(root_path, temp_char + name_img_png))
self._channel_2_mask = glob.glob(os