PyTorch同时读取两个数据集实现半监督学习

PyTorch同时读取两个数据集实现半监督学习

写在开头

本文是在kaggle上做的实验,所以直接从上面导了出来。后期应该还会更新,因为没写完。。。

https://www.kaggle.com/lartpang/segmentationdataloader

修改记录

  • 2020年5月7日:修改文中注释的小错误。

数据路径

UNLABELED_PATH = ["/kaggle/input/ecssd/ECSSD/Image", "/kaggle/input/ecssd/ECSSD/Mask"]
LABELED_PATH = ["/kaggle/input/pascal-s/Pascal-S/Image", "/kaggle/input/pascal-s/Pascal-S/Mask"]

TODO

  • 读取DUTS-TR和MixFlickrDUS用于训练
  • 每个batch都要保证包含1/4的DUTS-TR的数据集和3/4的MixFlickrDUTS
  • 针对训练集使用不同的增强方式
  • 尝试更多的方法

不考虑测试集,因为测试集完全可以使用一个独立的ImageFolder类构造。

方法一:通过对__getitem__的索引进行计算,按照比例关系选择对应数据集的数据

if index % (self.r_l_rate + 1) == 0:
    label_index = index // (self.r_l_rate + 1)
    img_path, gt_path = self.imgs_label[label_index]  # 0, 1 => 10550
else:
    unlabel_index = index // (self.r_l_rate + 1) + index % (self.r_l_rate + 1)
    img_path, gt_path = self.imgs_unlabel[unlabel_index]  # 1, 2, 3

主体代码:

import os

import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
import math


class JointResize(object):
    def __init__(self, size):
        if isinstance(size, int):
            self.size = (size, size)
        elif isinstance(size, tuple):
            self.size = size
        else:
            raise RuntimeError("size参数请设置为int或者tuple")

    def __call__(self, img, mask):
        img = img.resize(self.size)
        mask = mask.resize(self.size)
        return img, mask

def make_dataset(root, prefix=('jpg', 'png')):
    img_path = root[0]
    gt_path = root[1]
    img_list = [os.path.splitext(f)[0] for f in os.listdir(img_path) if f.endswith(prefix[0])]
    return [(os.path.join(img_path, img_name + prefix[0]), os.path.join(gt_path, img_name + prefix[1])) for img_name in img_list]


# 仅针对训练集
class ImageFolder(data.Dataset):
    def __init__(self, root, mode, in_size, prefix, use_bigt=False, split_rate=(1, 3)):
        """split_rate = label:unlabel"""
        assert isinstance(mode, str), 'isTrain参数错误,应该为bool类型'
        self.root_labeled = root[0]
        self.mode = mode
        self.use_bigt = use_bigt
        
        self.imgs_labeled = make_dataset(self.root_labeled, prefix=prefix)
        self.split_rate = split_rate
        self.r_l_rate = split_rate[1] // split_rate[0]
        len_labeled = len(self.imgs_labeled)

        self.root_unlabeled = root[1]
        self.imgs_unlabeled = make_dataset(self.root_unlabeled, prefix=prefix)
        len_unlabeled = len(self.imgs_unlabeled)

        len_unlabeled = self.r_l_rate * len_labeled
        self.imgs_unlabeled = self.imgs_unlabeled * (self.r_l_rate + math.ceil(len_labeled / len_unlabeled))  # 扩展无标签的数据列表
        self.imgs_unlabeled = self.imgs_unlabeled[0:len_unlabeled]

        self.length = len_labeled + len_unlabeled
        print(f"使用扩充比例为:{len(self.imgs_labeled) / len(self.imgs_unlabeled)}")

        # 仅是为了简单而仅使用一种变换
        self.train_joint_transform = JointResize(in_size)
        self.train_img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 处理的是Tensor
        ])
        # ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 np.ndarray 转换为形状为 D×H×W,
        # 数值范围为 [0.0, 1.0] 的 torch.Tensor。
        self.train_gt_transform = transforms.ToTensor()

    def __getitem__(self, index):
        if index % (self.r_l_rate + 1) == 0:
            labeled_index = index // (self.r_l_rate + 1)
            img_path, gt_path = self.imgs_labeled[labeled_index]  # 0, 1 => 10550
        else:
            unlabeled_index = index // (self.r_l_rate + 1) + index % (self.r_l_rate + 1)
            img_path, gt_path = self.imgs_unlabeled[unlabeled_index]  # 1, 2, 3

        img = Image.open(img_path).convert('RGB')
        img_name = (img_path.split(os.sep)[-1]).split('.')[0]

        gt = Image.open(gt_path).convert('L')
        img, gt = self.train_joint_transform(img, gt)
        img = self.train_img_transform(img)
        gt = self.train_gt_transform(gt)
        if self.use_bigt:
            gt = gt.ge(0.5).float()  # 二值化
        return img, gt, img_name  # 输出名字方便比较

    def __len__(self):
        return self.length
    
print(f" ==>> 使用的训练集 <<==\n -->> LABELED_PATH:{LABELED_PATH}\n -->> UNLABELED_PATH:{UNLABELED_PATH}")
train_set = ImageFolder((LABELED_PATH, UNLABELED_PATH), "train", 320, prefix=('.jpg', '.png'), use_bigt=True, split_rate=(12, 36))
# 由于train_set内部的比例顺序是固定的,所以为了保持比例关系,不能再使用`shuffle=True`
train_loader = DataLoader(train_set, batch_size=48, num_workers=8, shuffle=False, drop_last=True, pin_memory=True)  

def split_data(data):
    labeled_data = []
    unlabeled_data = []
    for i, item in enumerate(data):
        if i % 4 == 0:
            labeled_data.append(item)
        else:
            unlabeled_data.append(item)
    return labeled_data, unlabeled_data

for train_idx, train_data in enumerate(train_loader):
    train_inputs, train_gts, train_names = train_data
    print(train_names)
    
    # 正常训练中下面应该有,这里为了方便就关掉了
    # train_inputs = train_inputs.to(self.dev)
    # train_gts = train_gts.to(self.dev)
    train_labeled_inputs, train_unlabeled_inputs = split_data(train_inputs)
    train_labeled_gts, _ = split_data(train_gts)
    train_labeled_names, train_unlabeled_names = split_data(train_names)
    print("labeled_names ", train_labeled_names)
    print("unlabeled_names ", train_unlabeled_names)

    # otr_total = self.net(train_inputs)
    # labeled_otr, unlabeled_otr = otr_total.split((12, 36), dim=0)
    # with torch.no_grad():
    #     ema_unlabeled_otr = ema_model(train_unlabeled_inputs)
    print(" ==>> 一个Batch结束了 <<== ")
    if train_idx == 2:
        break
print(" ==>> 一个Epoch结束了 <<== ")
 ==>> 使用的训练集 <<==
 -->> LABELED_PATH:['/kaggle/input/pascal-s/Pascal-S/Image', '/kaggle/input/pascal-s/Pascal-S/Mask']
 -->> UNLABELED_PATH:['/kaggle/input/ecssd/ECSSD/Image', '/kaggle/input/ecssd/ECSSD/Mask']
使用扩充比例为:0.3333333333333333
['623', '0733', '0106', '0375', '764', '0106', '0375', '0285', '771', '0375', '0285', '0591', '208', '0285', '0591', '0799', '820', '0591', '0799', '0074', '473', '0799', '0074', '0077', '333', '0074', '0077', '0498', '537', '0077', '0498', '0610', '45', '0498', '0610', '0617', '369', '0610', '0617', '0426', '56', '0617', '0426', '0989', '654', '0426', '0989', '0235']
labeled_names  ['623', '764', '771', '208', '820', '473', '333', '537', '45', '369', '56', '654']
unlabeled_names  ['0733', '0106', '0375', '0106', '0375', '0285', '0375', '0285', '0591', '0285', '0591', '0799', '0591', '0799', '0074', '0799', '0074', '0077', '0074', '0077', '0498', '0077', '0498', '0610', '0498', '0610', '0617', '0610', '0617', '0426', '0617', '0426', '0989', '0426', '0989', '0235']
 ==>> 一个Batch结束了 <<== 
['89', '0989', '0235', '0273', '20', '0235', '0273', '0307', '275', '0273', '0307', '0444', '785', '0307', '0444', '0058', '212', '0444', '0058', '0748', '239', '0058', '0748', '0255', '792', '0748', '0255', '0128', '58', '0255', '0128', '0700', '150', '0128', '0700', '0364', '6', '0700', '0364', '0798', '109', '0364', '0798', '0246', '149', '0798', '0246', '0337']
labeled_names  ['89', '20', '275', '785', '212', '239', '792', '58', '150', '6', '109', '149']
unlabeled_names  ['0989', '0235', '0273', '0235', '0273', '0307', '0273', '0307', '0444', '0307', '0444', '0058', '0444', '0058', '0748', '0058', '0748', '0255', '0748', '0255', '0128', '0255', '0128', '0700', '0128', '0700', '0364', '0700', '0364', '0798', '0364', '0798', '0246', '0798', '0246', '0337']
 ==>> 一个Batch结束了 <<== 
['187', '0246', '0337', '0208', '521', '0337', '0208', '0680', '436', '0208', '0680', '0834', '76', '0680', '0834', '0861', '539', '0834', '0861', '0141', '355', '0861', '0141', '0742', '516', '0141', '0742', '0781', '71', '0742', '0781', '0474', '708', '0781', '0474', '0372', '474', '0474', '0372', '0933', '501', '0372', '0933', '0970', '815', '0933', '0970', '0327']
labeled_names  ['187', '521', '436', '76', '539', '355', '516', '71', '708', '474', '501', '815']
unlabeled_names  ['0246', '0337', '0208', '0337', '0208', '0680', '0208', '0680', '0834', '0680', '0834', '0861', '0834', '0861', '0141', '0861', '0141', '0742', '0141', '0742', '0781', '0742', '0781', '0474', '0781', '0474', '0372', '0474', '0372', '0933', '0372', '0933', '0970', '0933', '0970', '0327']
 ==>> 一个Batch结束了 <<== 
 ==>> 一个Epoch结束了 <<== 

方法二:直接在__getitem__中一次性读取最简化比例数量的样本

上面的用法虽然简单,直接在一个ImageFolder中对数据进行组合,但是这样会导致一个问题,训练的时候无法使用shuffle=True设定,对于训练并不完美。

除了这里的设置方式,还有一种值得参考:在PoolNet的设置中,是直接对于每次迭代按照1:1的比例输入,所以其在__getitem__中直接同时imread两个数据集的图像。虽然这样比较简单,但是却也是直接有效。

下面仿写一份。

import os

import torch.utils.data as data
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import math


class JointResize(object):
    def __init__(self, size):
        if isinstance(size, int):
            self.size = (size, size)
        elif isinstance(size, tuple):
            self.size = size
        else:
            raise RuntimeError("size参数请设置为int或者tuple")

    def __call__(self, img, mask):
        img = img.resize(self.size)
        mask = mask.resize(self.size)
        return img, mask

def make_dataset(root, prefix=('jpg', 'png')):
    img_path = root[0]
    gt_path = root[1]
    img_list = [os.path.splitext(f)[0] for f in os.listdir(img_path) if f.endswith(prefix[0])]
    return [(os.path.join(img_path, img_name + prefix[0]), os.path.join(gt_path, img_name + prefix[1])) for img_name in img_list]


# 仅针对训练集
class ImageFolder(data.Dataset):
    def __init__(self, root, mode, in_size, prefix, use_bigt=False, split_rate=(1, 3)):
        """split_rate = label:unlabel"""
        assert isinstance(mode, str), 'isTrain参数错误,应该为bool类型'
        self.mode = mode
        self.use_bigt = use_bigt
        self.split_rate = split_rate
        self.r_l_rate = split_rate[1] // split_rate[0]

        self.root_labeled = root[0]
        self.imgs_labeled = make_dataset(self.root_labeled, prefix=prefix)

        len_labeled = len(self.imgs_labeled)
        self.length = len_labeled

        self.root_unlabeled = root[1]
        self.imgs_unlabeled = make_dataset(self.root_unlabeled, prefix=prefix)
        
        len_unlabeled = self.r_l_rate * len_labeled
        
        self.imgs_unlabeled = self.imgs_unlabeled * (self.r_l_rate + math.ceil(len_labeled / len_unlabeled))  # 扩展无标签的数据列表
        self.imgs_unlabeled = self.imgs_unlabeled[0:len_unlabeled]

        print(f"使用比例为:{len_labeled / len_unlabeled}")

        # 仅是为了简单而仅使用一种变换
        self.train_joint_transform = JointResize(in_size)
        self.train_img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 处理的是Tensor
        ])
        # ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 
        # np.ndarray 转换为形状为 D×H×W,
        # 数值范围为 [0.0, 1.0] 的 torch.Tensor。
        self.train_gt_transform = transforms.ToTensor()

    def __getitem__(self, index):
        # 这里一次性读取最简化比例数量的样本,所有的样本需要单独处理
        img_labeled_path, gt_labeled_path = self.imgs_labeled[index]  
        # 0, 1 => 850
        img_labeled = Image.open(img_labeled_path).convert('RGB')
        img_labeled_name = (img_labeled_path.split(os.sep)[-1]).split('.')[0]

        gt_labeled = Image.open(gt_labeled_path).convert('L')
        back_gt_labeled = gt_labeled  
        # 用于无标签数据使用联合调整函数的时候代替无标签数据真值进行占位
        img_labeled, gt_labeled = self.train_joint_transform(img_labeled, gt_labeled)
        img_labeled = self.train_img_transform(img_labeled)
        gt_labeled = self.train_gt_transform(gt_labeled)
        if self.use_bigt:
            gt_labeled = gt_labeled.ge(0.5).float()  # 二值化
        data_labeled = [img_labeled, gt_labeled, img_labeled_name]
        
        data_unlabeled = [[], []]
        for idx_periter in range(self.r_l_rate):
            # 这里不再使用真值,直接使用`_`接收
            img_unlabeled_path, _ = self.imgs_unlabeled[index//self.r_l_rate+idx_periter]  # 0, 1, 2, 3 => 3*850
            img_unlabeled = Image.open(img_unlabeled_path).convert('RGB')
            img_unlabeled_name = (img_unlabeled_path.split(os.sep)[-1]).split('.')[0]

            img_unlabeled, _ = self.train_joint_transform(img_unlabeled, back_gt_labeled)
            # 这里为了使用那个联合调整的转换类,使用上面的target进行替代,但是要注意,不要再返回了
            img_unlabeled = self.train_img_transform(img_unlabeled)
                        
            data_unlabeled[0].append(img_unlabeled)
            data_unlabeled[1].append(img_unlabeled_name)

        return data_labeled, data_unlabeled  # 输出名字方便比较

    def __len__(self):
        return self.length
    
print(f" ==>> 使用的训练集 <<==\n -->> LABELED_PATH:{LABELED_PATH}\n -->> UNLABELED_PATH:{UNLABELED_PATH}")
train_set = ImageFolder((LABELED_PATH, UNLABELED_PATH), "train", 320, prefix=('.jpg', '.png'), use_bigt=True, split_rate=(12, 36))
# 由于train_set内部的比例顺序已经被固定到每一次iter中,所以可以使用`shuffle=True`
train_loader = DataLoader(train_set, batch_size=12, num_workers=8, shuffle=True, drop_last=False, pin_memory=True)  

for train_idx, train_data in enumerate(train_loader):
    data_labeled, data_unlabeled = train_data
    
    train_labeled_inputs, train_labeled_gts, train_labeled_names = data_labeled
    print(train_labeled_inputs.size(), train_labeled_gts.size(), train_labeled_names)
    
    train_unlabeled_inputs_list, train_unlabeled_names = data_unlabeled
    train_unlabeled_inputs = torch.cat(train_unlabeled_inputs_list, dim=0)
    print(train_unlabeled_inputs.size(), train_unlabeled_names)
    
    train_labeled_inputs_batchsize = train_labeled_inputs.size(0)
    train_unlabeled_inputs_batchsize = train_unlabeled_inputs.size(0)
    
    # 正常训练中下面应该有,这里为了方便就关掉了,这里之所以不先进行cat再进行to(dev),
    # 是为了便于后面ema_model输入的时候使用一个已经在gpu上的张量,免去了再次搬运的麻烦
    # train_labeled_inputs = train_labeled_inputs.to(dev)
    # train_unlabeled_inputs = train_unlabeled_inputs.to(dev)
    # train_gts = train_labeled_gts.to(self.dev)
    train_inputs = torch.cat([train_labeled_inputs, train_unlabeled_inputs], dim=0)

    # otr_total = net(train_inputs)
    # labeled_otr, unlabeled_otr = otr_total.split((train_labeled_inputs_batchsize, train_unlabeled_inputs_batchsize), dim=0)
    # with torch.no_grad():
    #     ema_unlabeled_otr = ema_model(train_unlabeled_inputs)
    print(" ==>> 一个Batch结束了 <<== ")
    if train_idx == 2:
        break
print(" ==>> 一个Epoch结束了 <<== ")
 ==>> 使用的训练集 <<==
 -->> LABELED_PATH:['/kaggle/input/pascal-s/Pascal-S/Image', '/kaggle/input/pascal-s/Pascal-S/Mask']
 -->> UNLABELED_PATH:['/kaggle/input/ecssd/ECSSD/Image', '/kaggle/input/ecssd/ECSSD/Mask']
使用比例为:0.3333333333333333
torch.Size([12, 3, 320, 320]) torch.Size([12, 1, 320, 320]) ('299', '566', '138', '678', '700', '457', '266', '310', '810', '743', '469', '592')
torch.Size([36, 3, 320, 320]) [('0387', '0094', '0578', '0462', '0399', '0377', '0807', '0970', '0287', '0591', '0514', '0500'), ('0508', '0069', '0818', '0314', '0068', '0453', '0850', '0749', '0469', '0252', '0572', '0914'), ('0847', '0232', '0609', '0716', '0287', '0457', '0294', '0225', '0591', '0538', '0626', '0931')]
 ==>> 一个Batch结束了 <<== 
torch.Size([12, 3, 320, 320]) torch.Size([12, 1, 320, 320]) ('26', '771', '37', '814', '248', '389', '848', '3', '66', '153', '448', '227')
torch.Size([36, 3, 320, 320]) [('0322', '0464', '0972', '0734', '0043', '0800', '0483', '0807', '0029', '0425', '0976', '0741'), ('0054', '0527', '0683', '0694', '0612', '0390', '0910', '0850', '0548', '0260', '0335', '0406'), ('0761', '0586', '0936', '0501', '0073', '0381', '0544', '0294', '0007', '0633', '0505', '0322')]
 ==>> 一个Batch结束了 <<== 
torch.Size([12, 3, 320, 320]) torch.Size([12, 1, 320, 320]) ('805', '635', '739', '56', '80', '78', '496', '575', '359', '379', '55', '354')
torch.Size([36, 3, 320, 320]) [('0032', '0164', '0314', '0407', '0165', '0734', '0540', '0501', '0137', '0058', '0740', '0053'), ('0470', '0464', '0716', '0740', '0413', '0694', '0671', '0834', '0707', '0387', '0186', '0876'), ('0053', '0527', '0601', '0186', '0800', '0501', '0218', '0524', '0679', '0508', '0588', '0578')]
 ==>> 一个Batch结束了 <<== 
 ==>> 一个Epoch结束了 <<== 

补充

上面的操作中,也可以考虑将img_unlabeledimg_labeled直接按照比例放到一起,而真值部分仅是返回gt_labeled,同时img_unlabeled_nameimg_labeled_name一起返回,下面是例子:

    def __getitem__(self, index):
        # 这里一次性读取最简化比例数量的样本,所有的样本需要单独处理
        total_img, labeled_gt, total_name = [], [], []
        
        img_labeled_path, gt_labeled_path = self.imgs_labeled[index]  # 0, 1 => 850
        img_labeled = Image.open(img_labeled_path).convert('RGB')
        img_labeled_name = (img_labeled_path.split(os.sep)[-1]).split('.')[0]

        gt_labeled = Image.open(gt_labeled_path).convert('L')
        back_gt_labeled = gt_labeled  
        # 用于无标签数据使用联合调整函数的时候代替无标签数据真值进行占位
        img_labeled, gt_labeled = self.train_joint_transform(img_labeled, gt_labeled)
        img_labeled = self.train_img_transform(img_labeled)
        gt_labeled = self.train_gt_transform(gt_labeled)
        if self.use_bigt:
            gt_labeled = gt_labeled.ge(0.5).float()  # 二值化
        total_img.append(img_labeled)
        labeled_gt.append(gt_labeled)
        total_name.append(img_labeled_name)
        
        for idx_periter in range(self.r_l_rate):
            # 这里不再使用真值,直接使用`_`接收
            img_unlabeled_path, _ = self.imgs_unlabeled[index//self.r_l_rate+idx_periter]  # 0, 1, 2, 3 => 3*850
            img_unlabeled = Image.open(img_unlabeled_path).convert('RGB')
            img_unlabeled_name = (img_unlabeled_path.split(os.sep)[-1]).split('.')[0]

            img_unlabeled, _ = self.train_joint_transform(img_unlabeled, back_gt_labeled)  
            # 这里为了使用那个联合调整的转换类,使用上面的target进行替代,但是要注意,不要再返回了
            img_unlabeled = self.train_img_transform(img_unlabeled)
                        
            total_img.append(img_unlabeled)
            total_name.append(img_unlabeled_name)

        return total_img, labeled_gt, total_name  # 输出名字方便比较

这样在返回之后只需要对数据进行分割后处理即可,但是这里的分割需要按照间隔分割,并不方便。

方法三:改造DataLoader

这一点主要受到了mean-teacher的启发。

class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices
    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in  zip(grouper(primary_iter, self.primary_batch_size),
                    grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

调用的时候:

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

这一部分需要分析下DataLoader的几个参数。

参考资料:

  • https://blog.csdn.net/u014380165/article/details/79058479
  • pytorch学习笔记(十四): DataLoader源码阅读 https://blog.csdn.net/u012436149/article/details/78545766
  • Pytorch中的数据加载艺术 http://studyai.com/article/11efc2bf
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data. 
            自定义的Dataset类的子类,实现了基本的数据的读取流程,例如获取地址列表、
            根据索引打开图片、图片预处理等等
        batch_size (int, optional): how many samples per batch to load 
            (default: ``1``). 
            如字面含义,确定了batchsize,可知batch是对数个样本的包装
        shuffle (bool, optional): set to ``True`` to have the data reshuffled 
            at every epoch (default: ``False``). 
            是否每个周期都打乱数据的原始顺序,一般是训练的时候为True,测试为False
        sampler (Sampler, optional): defines the strategy to draw samples 
            from the dataset. If specified, ``shuffle`` must be False. 
            定义了从数据中采样的策略,一次返回一个样本的索引,这是Sampler的子类,
            此时必须关闭shuffle操作,相当于你得自己实现
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
            和sampler类似,但是这个跟更进一步,定义了针对batch级别的数据的采样策略,与
            batch_size/shuffle/sampler/drop_last互斥,一次可以返回一个batch的索引
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
            读取数据使用的子进程数目,在一定程度上可以加快数据读取
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            是一个可调用的对象,用来合并样本,构建mini-batch
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your ``collate_fn`` returns a batch that is a custom type
            see the example below.
            如果为True,数据加载器在返回前将张量复制到CUDA固定内存中
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
            是否丢弃每个周期最后一个不完整的batch,如果存在的话
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
            是一个非负值,来指定从workers中获取数据的timeout参数,超过这个时间还没读取到数据的话就会报错
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``) 
            如果不是None,将在每个worker子进程上调用,使用worker id作为输入,在seeding之后以及数据加载之前
            (这个不太懂,目前还不理解用法)

    .. note:: When ``num_workers != 0``, the corresponding worker processes are created each time
              iterator for the DataLoader is obtained (as in when you call
              ``enumerate(dataloader,0)``).
              At this point, the dataset, ``collate_fn`` and ``worker_init_fn`` are passed to each
              worker, where they are used to access and initialize data based on the indices
              queued up from the main process. This means that dataset access together with
              its internal IO, transforms and collation runs in the worker, while any
              shuffle randomization is done in the main process which guides loading by assigning
              indices to load. Workers are shut down once the end of the iteration is reached.

              Since workers rely on Python multiprocessing, worker launch behavior is different
              on Windows compared to Unix. On Unix fork() is used as the default
              muliprocessing start method, so child workers typically can access the dataset and
              Python argument functions directly through the cloned address space. On Windows, another
              interpreter is launched which runs your main script, followed by the internal
              worker function that receives the dataset, collate_fn and other arguments
              through Pickle serialization.

              This separate serialization means that you should take two steps to ensure you
              are compatible with Windows while using workers
              (this also works equally well on Unix):

              - Wrap most of you main script's code within ``if __name__ == '__main__':`` block,
                to make sure it doesn't run again (most likely generating error) when each worker
                process is launched. You can place your dataset and DataLoader instance creation
                logic here, as it doesn't need to be re-executed in workers.
              - Make sure that ``collate_fn``, ``worker_init_fn`` or any custom dataset code
                is declared as a top level def, outside of that ``__main__`` check. This ensures
                they are available in workers as well
                (this is needed since functions are pickled as references only, not bytecode).

              By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraies
              may be duplicated upon initializing workers (w.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.

    The default memory pinning logic only recognizes Tensors and maps and iterables
    containg Tensors.  By default, if the pinning logic sees a batch that is a custom type
    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
    or if each element of your batch is a custom type, the pinning logic will not
    recognize them, and it will return that batch (or those elements)
    without pinning the memory.  To enable memory pinning for custom batch or data types,
    define a ``pin_memory`` method on your custom type(s).
    默认的内存固定逻辑仅识别张量,包含张量的映射和迭代。
    默认情况下,如果固定逻辑看到一个自定义类型的批处理(如果您有一个返回自定义批处理类型
    的collate_fn,或者如果批处理的每个元素都是自定义类型,则会发生这种情况) 逻辑将无法
    识别它们,它将返回该批次(或那些元素)并且不固定内存。要为自定义批处理或数据类型启用
    内存固定,请在自定义类型上定义`pin_memory`方法。

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())

    """

这里主要关注参数中的samplerbatch_sampler以及collate_fn的用法。

samplerbatch_sampler

首先可以看默认要求是如何:

        # batch_sampler指定的时候,要求batch_size=1/shuule=False/sampler=None/drop_last=False
        # 也就是batch_sampler需要完成读取并划分batch、置乱数据、处理最后的batch等需求
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        # sampler指定的时候,要求shuffle=False,也就是sampler需要完成数据的获取打乱的需求
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # batch_sampler和sampler都没有指定的时候,sampler根据shuffle来确定默认的设置为
        # RandomSampler和SequentialSampler,可以看出来,一个是随机抽取(所谓置乱)一个是
        # 按照顺序抽取,而batch_sampler设置为BatchSampler,所以说,若想要自己实现batch_sampler
        # 或者sampler,只要模仿这三个类即可
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制(http://studyai.com/article/11efc2bf)

查看这几个类,这里的代码来自V1.1.0

import torch
from torch._six import int_classes as _int_classes


class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    
    每个Sampler的子类(后面的那些采集数据的类)都要包含下面这几个方法
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        # 在V1.2.0中,没有了这个需求:https://github.com/pytorch/pytorch/blob/v1.2.0/torch/utils/data/sampler.py#L23-L48
        raise NotImplementedError


class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
        
    保证每个周期按照固定的顺序读取,所以这里直接使用了range(len(self.data_source))作为顺序
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
            
    返回随机打乱后的索引迭代器
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    # 关于该装饰器:https://www.programiz.com/python-programming/property
    # 这里为私有属性提供了一个接口
    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            # 这里可以重新制定索引列表长度(=self.num_samples),索引列表最大值(=len(self.data_source)是固定的
            # 同时,self.replacement参数指定了是否使用可重复抽样
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # torch.randperm(n) Returns a random permutation of integers from 0 to n - 1. 
        # https://pytorch.org/docs/1.1.0/torch.html#torch.randperm
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples


class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.
    Arguments:
        indices (sequence): a sequence of indices
        
    这里是对于原有索引序列取出一个子集
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)


class WeightedRandomSampler(Sampler):
    r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
            为True的时候,可以理解为有放回抽取,False可以理解为无放回抽取
    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
        
    这里根据对应的概率来采样样本,确定索引迭代器
    """

    def __init__(self, weights, num_samples, replacement=True):
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        # torch.multinomial多项式分布根据权重进行采样:
        # https://baike.baidu.com/item/%E5%A4%9A%E9%A1%B9%E5%88%86%E5%B8%83
        # https://pytorch.org/docs/1.1.0/torch.html#torch.multinomial
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples

 
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.
    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        
    BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # 这里使用yield生成最终的迭代batch
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # 这里判断了最后一个可能存在的不完整的batch
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            # 舍弃最后一个不完整的batch,向下取整
            return len(self.sampler) // self.batch_size
        else:
            # 若能整除,则self.batch_size-1整除后没有影响,因为结果为0
            # 若是不能整除,则len(self.sampler)必然要比self.batch_size的整数倍多出
            # [1, self.batch_size-1]的这个闭区间范围的值,
            # 所以再加上一个该范围最大的值self.batch_size-1必定会位于
            # [len(self.sampler), 
            #  (len(self.sampler) // self.batch_size+1) * self.batch_size]
            # 该区间内,结果正好多出来一个需要的(+1)
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。

[x for x in range(10)], range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.

sampler = [x for x in range(10)]
print(f"原始Sampler:{sampler}")

from torch.utils.data.sampler import SequentialSampler
print(f"顺序采样:{[x for x in SequentialSampler(sampler)]}")

from torch.utils.data.sampler import RandomSampler
print(f"随机重复采样:{[x for x in RandomSampler(data_source=sampler, replacement=True, num_samples=5)]}")
print(f"随机不重复采样:{[x for x in RandomSampler(data_source=sampler, replacement=False)]}")
原始Sampler:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
顺序采样:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
随机重复采样:[9, 4, 9, 6, 4]
随机不重复采样:[9, 6, 1, 3, 2, 7, 0, 4, 5, 8]

collate_fn

参考资料:

  • https://jdhao.github.io/2017/10/23/pytorch-load-data-and-make-batch/#loading-variable-size-input-images
  • https://www.cnblogs.com/king-lps/p/10990304.html

查看源代码https://github.com/pytorch/pytorch/blob/v1.1.0/torch/utils/data/_utils/collate.py#L31

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size
    
    一般是输入的batch中第一位为图像,第二位为标签,所以这里直接判断第一位的类型。
    第二位上也需要考虑是否可以被stack,
    对于分割任务而言,真值也是图片,所以也得保证图片有着相同的大小
    将batch中的数据进行整理,将一系列图像和目标打包为张量(张量的第一个维度为批大小)。
    
    The default `collate_fn` expects all the images in a batch to have the same size 
    because it uses `torch.stack()` to pack the images. If the images provided by 
    Dataset have variable size, you have to provide your custom `collate_fn`.
    """

    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))

            return default_collate([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
        
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg_fmt.format(type(batch[0]))))

这里个根据输入的类型来实现对于不同类别的数据的返回与划分。可见有几处使用了递归的操作重复用了该函数。

import os

import torch.utils.data as data
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import math


class JointResize(object):
    def __init__(self, size):
        if isinstance(size, int):
            self.size = (size, size)
        elif isinstance(size, tuple):
            self.size = size
        else:
            raise RuntimeError("size参数请设置为int或者tuple")

    def __call__(self, img, mask):
        img = img.resize(self.size)
        mask = mask.resize(self.size)
        return img, mask

def make_dataset(root, prefix=('jpg', 'png')):
    img_path = root[0]
    gt_path = root[1]
    img_list = [os.path.splitext(f)[0] for f in os.listdir(img_path) if f.endswith(prefix[0])]
    return [(os.path.join(img_path, img_name + prefix[0]), os.path.join(gt_path, img_name + prefix[1])) for img_name in img_list]


# 仅针对训练集
class ImageFolder(data.Dataset):
    def __init__(self, root, mode, in_size, prefix, use_bigt=False, split_rate=(1, 3)):
        """split_rate = label:unlabel"""
        assert isinstance(mode, str), 'mode参数错误,应该为str类型'
        self.mode = mode
        self.use_bigt = use_bigt
        self.split_rate = split_rate
        self.r_l_rate = split_rate[1] // split_rate[0]

        self.root_labeled = root[0]
        self.imgs_labeled = make_dataset(self.root_labeled, prefix=prefix)

        len_labeled = len(self.imgs_labeled)
        self.length = len_labeled

        self.root_unlabeled = root[1]
        self.imgs_unlabeled = make_dataset(self.root_unlabeled, prefix=prefix)
        
        len_unlabeled = self.r_l_rate * len_labeled
        
        self.imgs_unlabeled = self.imgs_unlabeled * (self.r_l_rate + math.ceil(len_labeled / len_unlabeled))  # 扩展无标签的数据列表
        self.imgs_unlabeled = self.imgs_unlabeled[0:len_unlabeled]

        print(f"使用比例为:{len_labeled / len_unlabeled}")

        # 仅是为了简单而仅使用一种变换
        self.train_joint_transform = JointResize(in_size)
        self.train_img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 处理的是Tensor
        ])
        # ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 np.ndarray 转换为形状为 D×H×W,
        # 数值范围为 [0.0, 1.0] 的 torch.Tensor。
        self.train_gt_transform = transforms.ToTensor()

    def __getitem__(self, index):
        # 这里一次性读取最简化比例数量的样本,所有的样本需要单独处理
        img_labeled_path, gt_labeled_path = self.imgs_labeled[index]  # 0, 1 => 850
        img_labeled = Image.open(img_labeled_path).convert('RGB')
        img_labeled_name = (img_labeled_path.split(os.sep)[-1]).split('.')[0]

        gt_labeled = Image.open(gt_labeled_path).convert('L')
        back_gt_labeled = gt_labeled  
        # 用于无标签数据使用联合调整函数的时候代替无标签数据真值进行占位
        img_labeled, gt_labeled = self.train_joint_transform(img_labeled, gt_labeled)
        img_labeled = self.train_img_transform(img_labeled)
        gt_labeled = self.train_gt_transform(gt_labeled)
        if self.use_bigt:
            gt_labeled = gt_labeled.ge(0.5).float()  # 二值化
        data_labeled = [img_labeled, gt_labeled, img_labeled_name]
        
        data_unlabeled = [[], []]
        for idx_periter in range(self.r_l_rate):
            # 这里不再使用真值,直接使用`_`接收
            img_unlabeled_path, _ = self.imgs_unlabeled[index // self.r_l_rate + idx_periter]  
            # 0, 1, 2, 3 => 3*850
            img_unlabeled = Image.open(img_unlabeled_path).convert('RGB')
            img_unlabeled_name = (img_unlabeled_path.split(os.sep)[-1]).split('.')[0]

            img_unlabeled, _ = self.train_joint_transform(img_unlabeled, back_gt_labeled)  
            # 这里为了使用那个联合调整的转换类,使用上面的target进行替代,但是要注意,不要再返回了
            img_unlabeled = self.train_img_transform(img_unlabeled)
                        
            data_unlabeled[0].append(img_unlabeled)
            data_unlabeled[1].append(img_unlabeled_name)

        return data_labeled, data_unlabeled  # 输出名字方便比较

    def __len__(self):
        return self.length
    
    
def my_collate(batch):
    # 针对送进来的一个batch的数据进行整合,batch的各项表示各个样本
    # batch 仅有一项 batch[0] 对应于下面的 train_data
    # batch[0][0], batch[0][1] <==> data_labeled, data_unlabeled = train_data
    # batch[0][0][0], batch[0][0][1], batch[0][0][2] <==> train_labeled_inputs, train_labeled_gts, train_labeled_names = data_labeled
    # batch[0][1][0], batch[0][2][1] <==> train_unlabeled_inputs_list, train_unlabeled_names = data_unlabeled
    
    # 最直接的方法:
    train_labeled_inputs, train_labeled_gts, train_labeled_names = [], [], []
    train_unlabeled_inputs_list, train_unlabeled_names = [], []
    for batch_iter in batch:
        x, y = batch_iter
        train_labeled_inputs.append(x[0])
        train_labeled_gts.append(x[1])
        train_labeled_names.append(x[2])
        
        train_unlabeled_inputs_list += y[0]
        train_unlabeled_names += y[1]

    train_labeled_inputs = torch.stack(train_labeled_inputs, 0)
    train_unlabeled_inputs_list = torch.stack(train_unlabeled_inputs_list, 0)
    train_labeled_gts = torch.stack(train_labeled_gts, 0)
    print(train_unlabeled_inputs_list.size())
    return ([train_labeled_inputs, train_unlabeled_inputs_list], 
            [train_labeled_gts],
            [train_labeled_names, train_unlabeled_names])

print(f" ==>> 使用的训练集 <<==\n -->> LABELED_PATH:{LABELED_PATH}\n -->> UNLABELED_PATH:{UNLABELED_PATH}")
train_set = ImageFolder((LABELED_PATH, UNLABELED_PATH), "train", 320, prefix=('.jpg', '.png'), use_bigt=True, split_rate=(3, 9))
# a simple custom collate function, just to show the idea
train_loader = DataLoader(train_set, batch_size=3, num_workers=4, collate_fn=my_collate, shuffle=True, drop_last=False, pin_memory=True)
print(" ==>> data_loader构建完毕 <<==")

for train_idx, train_data in enumerate(train_loader):

    train_inputs, train_gts, train_names = train_data
    
    train_labeled_inputs, train_unlabeled_inputs = train_inputs
    train_labeled_gts = train_gts[0]
    train_labeled_names, train_unlabeled_names = train_names
    print("-->>", train_labeled_inputs.size(), train_labeled_gts.size(), train_labeled_names)
    print("-->>", train_unlabeled_inputs.size(), train_unlabeled_names)
    
    train_labeled_inputs_batchsize = train_labeled_inputs.size(0)
    train_unlabeled_inputs_batchsize = train_unlabeled_inputs.size(0)
    
    # 正常训练中下面应该有,这里为了方便就关掉了,这里之所以不先进行cat再进行to(dev),是为了便于后面ema_model输入的时候使用一个已经在gpu上的张量,免去了再次搬运的麻烦
    # train_labeled_inputs = train_labeled_inputs.to(dev)
    # train_unlabeled_inputs = train_unlabeled_inputs.to(dev)
    # train_gts = train_labeled_gts.to(self.dev)
    train_inputs = torch.cat([train_labeled_inputs, train_unlabeled_inputs], dim=0)

    # otr_total = net(train_inputs)
    # labeled_otr, unlabeled_otr = otr_total.split((train_labeled_inputs_batchsize, train_unlabeled_inputs_batchsize), dim=0)
    # with torch.no_grad():
    #     ema_unlabeled_otr = ema_model(train_unlabeled_inputs)
    print(" ==>> 一个Batch结束了 <<== ")
    if train_idx == 0:
        break
print(" ==>> 一个Epoch结束了 <<== ")
 ==>> 使用的训练集 <<==
 -->> LABELED_PATH:['/kaggle/input/pascal-s/Pascal-S/Image', '/kaggle/input/pascal-s/Pascal-S/Mask']
 -->> UNLABELED_PATH:['/kaggle/input/ecssd/ECSSD/Image', '/kaggle/input/ecssd/ECSSD/Mask']
使用比例为:0.3333333333333333
 ==>> data_loader构建完毕 <<==
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
-->> torch.Size([3, 3, 320, 320]) torch.Size([3, 1, 320, 320]) ['783', '5', '116']
-->> torch.Size([9, 3, 320, 320]) ['0817', '0128', '0743', '0214', '0763', '0344', '0818', '0609', '0809']
 ==>> 一个Batch结束了 <<== 
 ==>> 一个Epoch结束了 <<== 

More

  • 19
    点赞
  • 83
    收藏
    觉得还不错? 一键收藏
  • 21
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值