PyTorch日积月累_4_utilities和torchvision


这篇博客主要介绍 torch.utils.data的数据集载入、 torchvision.transforms的数据预处理。

torch.utils.data

首先需要通过构造单个数据的torch.utils.data.Dataset类 或者 torch.utils.data.IterableDataset类,前者是映射类型(map-type)数据集,后者是可迭代类型(Iterable-style)数据集;然后使用torch.utils.data.DataLoader类 载入数据。

构造映射类型的数据集

该种类型的数据集,每个数据都有一个索引(可能是数据的地址),通过该索引可以得到对应的数据。

构造方法

构造方法:继承torch.utils.data.Dataset类,并重写__getitem__方法和__len__方法,前者就是索引符号[]调用的方法,通过索引返回对应数据的张量,后者返回数据的数量。

class Dataset(object):
    def __getitem__(self,index):
        # index 数据索引,0 ~ 数据个数-1
        pass
        # 返回数据张量
	def __len__(self):
        # 返回数据的数量
        pass
例1 torchvision.DatasetFolder

torchvision中的DatasetFolder为例子进行说明,该类的使用场景是数据集文件夹下的每个子文件夹对应一个类别,在DatasetFolder的构造函数中,self.samples是一个(数据路径,预测目标)的列表,self.loader则可以根据路径读取数据:

class VisionDataset(data.Dataset):
    def __init__(self,root,transforms=None,transform=None,target_transform=None):
        # ...
        pass
    def __getitem__(self,index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
        
class DatasetFolder(VisionDataset):
    def __init__(...):
        self.samples = ...
        self.loader  = ...
        # ...
    def __getitem__(self,index):
        path, target = self.samples[index]
        sample = self.loader(path)
        ...
        return sample,target
    def __len__(self):
        return len(self.samples)

像这样,通过构建一个数据地址和数据索引的映射列表的方式,进行数据的读取的方式十分常见。

例2 用于猫狗分类的数据抽象类 DogCat

主要涉及到两个模块torch.utils.Datasettorch.utils.DataLoader

  • Dataset需要重写__getitem____len__方法

    from torch.utils import data
    import os,torch
    import matplotlib.pyplot as PIL
    import numpy as np
    
    class DogCat(data.Dataset):
        def __init__(self,root):
            super(DogCat,self).__init__()
            imgs = os.listdir(root)
            self.imgs = [os.path.join(root,img) for img in imgs]
            
        def __getitem__(self,index):
            img_path = self.imgs[index]
            pil_img = PIL.Image.open(img.path) 
            data = torch.from_numpy(np.asnumpy(pil_img)) # PIL Image -> tensor
            label = 1 if 'dog' in img.path.split('/')[-1] else 0
            return data,label
        
        def __len__(self):
            return len(self.imgs)
    
    

    重点是完成self.imgs存放图片地址

  • data.DataLoader 加载数据

    dataloader = DataLoader(dataset,batch_size,shuffle=False,sampler=None)
    # dataloader是一个可迭代对象
    batch_data,batch_labels = next(iter(dataloader))
    

问题:遇到损坏无法加载的图片应该怎么办?

class NewCatDog(CatDog):
    def __init__(self,root):
        super(NewCatDog).__init__()
        imgs = os.listdir(root)
        self.imgs = [os.join(root,img) for img in imgs]

    def __getitem__(self,index):
        try:
            return super(NewCatDog,self).__getitem__(index)
        except:
            return None,None
    def __len__(self):
        return len(self.imgs)

# 同时还应在DataLoader中修改
from torch.utils.data.dataloader import default_collate
def my_collate(batch):
    # 过滤None元素
    batch = list(filter(lambda x: x[0] is not None, batch))
    # 考虑全部损坏的极端情况
    if len(batch) ==0:return torch.Tensor()
    return default_collate(batch)

dataset = NewCatDog(root)
dataloader = DataLoader(dataset,collate_fn = my_collate)

不过更好的做法是,将好的图像替换这张损坏的图像。

Note:
  • 高负载行为放在 dataset中,例如加载图片
  • dataset中应避免放入需要修改的对象,否则在多进程/多线程中,需要加锁,但是dataloader又很难加锁,所以尽量不要修改,否则建议使用Python标准库中的Queue结构。

构造可迭代类型的数据集

该方法基于torch.utils.data.IterableDataset,不需要重写__getitem____len__方法,只需要通过重写__iter__方法,并考虑多进程分割数据的情况,分配索引。

构造方法:

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self,start,end):
        super().__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None: # 单进程
            iter_start = self.start
            iter_end = self.end
        else:
            per_worker = int(math.ceil((self.end-self.start)/float(worker_info.num_workers))) # 将数据N等分
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker,self.end)
        return iter(range(iter_start,iter_end))

采样模块torch.utils.data.sampler

torch.utils.data.sampler

设置shuffle=True时,自动调用RandomSampler,此时会按照顺序一个个采样,当然也可以设置为WeightedRandomSampler,此时会按照权重进行采样,当inplacement==True时,允许采样器在一个epoch中重复采样某一个样本,一般设置为true,否则当样本数小于num_samplers时,会导致weights参数失效。

Note: 权重只与weights列表的比值有关

weights = [2 if label ==1 else 1 for data,label in dataset]
from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,num_samples=9,replacement=True)
dataloader = DataLoader(dataset,
                        batch_size=3,
                        sampler=sampler)
for datas, labels in dataloader:
    print(labels.tolist())

torchvision

torchvision概览

torchvision 主要包含三部分

  • models

  • datasets

    MNIST,CIFAR10,ImageNet,COCO

  • transforms

    • Tensor类型的处理: Normalize, ToPILImage
    • PIL Image 类型的处理: ToTensor, Crop, scale, padding

此外,torchvision.utils提供了两个函数make_grid

torchvision.transforms

torchvision.transforms主要实现两个功能:

  1. 针对Tensor类型

    1. Normalize(mean,std)
    2. ToPILImage()
  2. 针对PILImage类型

    1. ToTensor(),同时归一化[0,1]
    2. Scale,长宽比不变的情况下缩放图片,以较短边为准
    3. CenterCrop,RandomCrop,RandomResizeCrop
    4. Pad

利用torchvision.transforms.Compose()将其拼接起来,注意,这些操作都是定义了一个函数,后续还需要再调用函数,才能生效,相当于调用了__call__方法,则以下的调用方式是不对的

torchvision.transforms.ToTensor(img) # 不能直接调用

andomCrop,RandomResizeCrop4.Pad`

利用torchvision.transforms.Compose()将其拼接起来,注意,这些操作都是定义了一个函数,后续还需要再调用函数,才能生效,相当于调用了__call__方法,则以下的调用方式是不对的

torchvision.transforms.ToTensor(img) # 不能直接调用

在实际使用中,还需要注意进行维度的变化,尤其是单张图片的batch维度的变化。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值