pytorch--dataloader

注:以下图片来自sung kim的pytorch lecture
在dataset很小的情况下人工data feed比较简单,如下我们feed所有的data给model,然后利用output去计算损失函数和梯度,然后更新权重,但是当data size很大的情况,我们不能把data一次性feed给model,也不能计算出所有的梯度
在这里插入图片描述
所以通常的做法是我们把所有的data分成小的batches,然后go through each batch at once.
在这里插入图片描述
如何实现这些呢,利用pytorch的dataloader我们很容易做到这些,我们不需要关心如何进行划分,我们只需要关注可以被iterable的bacth,把它用于训练即可。
在这里插入图片描述
自定义dataloader需要自定义dataset,实例化该dataset,然后feed到DataLoader,其中可以规定bactch_size,其次shuffle意味着是否随机选取,当需要将该loader应用到两个以上进程数,可以利用num_workers 来定义。
在这里插入图片描述
enumerate(sequence, [start=0])
在这里插入图片描述
有很多数据集已经实现了dataset,如minist,不需要下载或者人工分割,如下面这样使用即可

import torch
import torchvision
import torchvision.transforms as transforms
transform=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ]
)
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,
                                      transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,
                                        shuffle=True,num_workers=2)
testset=torchvision.datasets.CIFAR10(root='./data',train=False,
                                     download=True,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,
                                       shuffle=False,num_workers=22)
classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')


下面是根据pytorch 官方文档关于dataloader详细介绍

custom datasets

torch.utils.data.Dataset是抽象类,定义一个新的dataset,需要继承这个抽象类,并且需要
重载__len__和__getitem__方法
__len_方法返回dataset的大小
__getitem__方法用于支持索引 如dataset[i]表示索引第i个样本
我们在初始化构造函数中读入csv文件,但是读入图片放在__getitem__方法内,这样可以节约内存,因为这样图片是在需要时才载入内存。

#自定义dataset
class FaceLandmarkDataset(Dataset):
    def __init__(self,csv_file,root_dir,transform=None):
        self.landmarks_frame=pd.read_csv(csv_file)
        self.root_dir=root_dir
        self.transform=transform
        
    def __len__(self):
        return len(self.landmarks_frame)
    
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx=idx.tolist()
        img_name=os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
        image=io.imread(img_name)
        landmarks=self.landmarks_frame.iloc[idx,1:]
        landmarks=np.array([landmarks])
        landmarks=landmarks.astype('float').reshape(-1,2)
        sample={'image':image,'landmarks':landmarks}
        
        if self.transform:
            sample=self.transform(sample)
        
        return sample
 #实例化该类
 face_dataset=FaceLandmarkDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')
 
transforms

所有的样本都不是同一大小,所以我们需要做一些预处理步骤,如下:
rescale:放缩到同一大小
randcrop:随机裁剪,数据增大
totensor:将numpy image变成torch image(需要转换轴)
我们会把以上写出可以被调用的类,而不是简单的函数,所以transform的参数不需要在被调用的需要每次都进行传递,我们只需要实现__call__和__init__这两个函数接口

rescale
class Rescale(object):
    def __init__(self,output_size):#output_size可以是int或者元组
        assert isinstance(output_size,(int,tuple))
        '''
        assert expression
        等价于
        if not expression:
        raise AssertionError
        表达式的值为false触发异常
        
        isinstance()判断是否是某类型的实例
        '''
        self.output_size=output_size
    
    def __call__(self,sample):
        image, landmarks = sample['image'], sample['landmarks']
        #在上面dataset里最后sample是字典类型
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        '''
        如果是int类型,若宽度比高度大,即固定高度为该整数,进行放缩
        '''
        new_h, new_w = int(new_h), int(new_w)
        
        img = transform.resize(image, (new_h, new_w))
        
        #对于landmark坐标也进行相应的变换
        landmarks = landmarks * [new_w / w, new_h / h]
        return {'image': img, 'landmarks': landmarks}
randomcrop
class RandomCrop(object):
    '''
    outputsize是整数则正方形剪裁,若是元组则按元组数据进行剪裁
    '''
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x Channel
        # torch image: C X H X W
        #利用transpose更换轴
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
将定义好的transform应用到samples上

可以通过compose将这些transform组合起来

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

自定义dataloder
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

关于torchvision

torchvision包提供了一些常见的dataset和transform,你不需要去自定义他们,其中imagefolder是其中一种通用的dataset类,它默认图片的组织结构如下:
root/ants/xxx.png,其中ants类标签,另外torchvision包下的的transform操作如Scale等也可以直接使用,使用方法如下

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

待续

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值