pytorch训练之数据加载步骤

本文章以ReID的数据加载为例。



from torch.utils.data import dataset, dataloader
from torchvision import transforms
一、建立自定义数据处理方法类:如随机擦除,随机裁剪等
代码:
class RandomErasing(object):
    def __init__(self,probability=0.5)
    
    def __call__(self, img)
        ...
        return img
二、建立数据预处理组合类实例:如图像翻转,归一化,向量化,擦除等
代码:
train_transform = transforms.Compose([
            transforms.Resize((384, 128), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0])
        ])

三、建立数据读取类:从本地路径进行数据加载,形成列表等
代码:
from torchvision.datasets.folder import default_loader  //解释见最后

class Market(dataset.Dataset):
    def __init__(self, transform, dtype, data_path):
        self.loader = default_loader

    def __getitem__(self, index):
        ...
        //根据路径生成图像与标签列表
        //加载图像
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img,target
    def __len__(self):
        return len(self.imgs)

四、生成torch数据流类实例:

self.train_loader = dataloader.DataLoader(self.trainset,
                                                  sampler=RandomSampler(self.trainset, batch_id=opt.batchid,
                                                                        batch_image=opt.batchimage),
                                                  batch_size=opt.batchid * opt.batchimage, num_workers=8,
                                                  pin_memory=True)
self.test_loader = dataloader.DataLoader(self.testset, batch_size=opt.batchtest, num_workers=8, pin_memory=True)

然后就可以在训练阶段使用迭代方法进行数据获取了。


torchvision.datasets.folder中的default_loader函数:

该函数主要分两种情况调用两个函数,一般采用pil_loader函数。

def pil_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else: #get_image_backend() == 'PIL'
        return pil_loader(path)
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值