pytorch tutorial: Data loading and processing

本文介绍了如何在PyTorch中创建自定义数据集,包括自定义transform、DataLoader的使用,以及一些实用的辅助函数。通过实例展示了如何加载和处理数据,包括图片的预处理和批量处理。
摘要由CSDN通过智能技术生成

前言

pytorch用了一段时间,后来又不用了,所以重新捡起pytorch再学一遍。

本次的资料是来自官方教程pytorch自定义数据导入和处理教程,教程中使用的是skimage处理,在我的实现中,我使用了PIL处理图片。

Self-made dataset (使用Dataset自定义数据集)


通过继承torch.utils.data.Dataset,我们可以实现自定义的数据集。继承之后,必须覆盖以下两种方法:

  • __len__: 获得数据集的大小
  • __getitem__: 支持下标索引访问,如dataset[i]

_init__构造函数里预加载文件,在__len__中定义大小, 在__getitem__中定义用idx获得单个样本的方法。

以下是使用PIL和pandas加载数据集的例子,样例来自官网


class LandmarkFacesDataset(Dataset):

    def __init__(self, csv_file, root, transform=None):
        """
        :params csv_file(string): Path to the csv file with annotations
        :params root(string): Path to images
        :params transform(callable, optional): Optional transform ot be applied
            on a sample
        """
        super(LandmarkFacesDataset, self).__init__()
        self.root = root
        self.data = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        assert idx >= 0 and idx < self.__len__(), "IndexError: idx must >= {} and < {}".format(0, self.__len__())
        name = self.data.iloc[idx, 0]
        pixels = self.data.iloc[idx, 1:].values.reshape(-1, 2)
        # load image array
        image_arr = np.array(Image.open(self.root + name))
        
        sample = {
   "images":image_arr, "landmarks":pixels}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

注意

  • 单个sample是dict类型,包括一个样本包含的所有必须的信息。
  • __getitem__函数只返回对应idx的一个样本,在DataLoader内部暂时我们不需要关心怎么调用的。

Self-made transform(自定义transform)


自定义的transform其实是基于其他的图像处理库,比如说PIL.Image,scikit-image(skimage),openCV,一般它们将图片和numpy.ndarray类型进行互转。

定义的transform除了必要的时候需要改写构造函数外,重写__call__方法,初始化后就可以直接像函数一样调用了。同时,这也使得它可以传入transforms.Compose()作为参数,对图片进行连续处理。

如下是一个对图片进行放缩的函数,该函数必须 特别注意长宽,不要弄反了

class Rescale1(object):
    """对图片进行放缩或者调整大小,
        注意长宽在选择的时候是反的
    """
    def __init__(self, output_size):

        assert(isinstance(output_size, (int, tuple)))
        self.output_size = output_size

    def __call__(self, sample):

        image, landmarks = sample["images"], sample["landmarks"]

        h,w = image.shape[:2]
        if isinstance(self.output_size, int):

            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h

        else:
            new_h, new_w = self.output_size 

        new_h, new_w = int(new_h), int(new_w)

        img = Image.fromarray(image)
        img = img.resize((new_w, new_h))
        img = np.array(img, dtype=image.dtype)

        landmarks = landmarks * np.array([new_w / w, new_h / h])

        return {
   "images":img, "landmarks": landmarks}

以下是随机裁减和转化为tensor的变换:

class RandomCrop(object):
    """对图片进行随机裁减
    """
    def __init__(self, output_size):

        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值