【Deep learning with pytorch 自学教程】Dataset&Dataloder

【Deep learning with pytorch 自学教程】数据集处理:Dataset&Dataloder 的使用

前言

在看过了许许多多的科研大佬的论文之后,发现pytorch 被越来越多人使用喜爱,于是也想尝试使用pytorch开始我的科研之路。这一系列博客主要参考pytorch官网自带的教程,并模仿教程中的代码实现建议对照着网站的教程看这篇博客。

数据集

数据集的读取与预处理对于深度学习是必不可少的部分。在tensorflow中数据的读取需要转化为tfrecord形式的文件进行读取,而在pytorch中将数据的读取打包为一个类(Dataset Class)在该类的基础上可以拓展子父类,实现对数据集的预处理和迭代读取。

打包数据集(Dataset Class)

torch.utils.data.Dataset是一个代表数据集的抽象类,想要为自己的神经网络打包一个数据集可以继承这个抽象类,并需要重写这个类中以下几个函数:

__len__函数:len(dataset)可以返回数据集的大小;
__getitem__函数: 实现通过下表读取数据集,如:dataset[i]返回数据集中第i个数据

打包数据集可以这样安排,在_ _init _ _函数中写入数据的读取路径与读取方式(如csv文件可以通过pd.read_csv(file)读取,npy文件可以通过np.load(file)读取,但在这里还没有真正的读取,只是定义读取的方式,真正的读取在getitem之中),在__getitem__函数中将具体的文件读出来。这样可以节省内存,不需要将所有的数据都放在内存里,需要哪一个读哪一个就好了。在getitem函数之中,返回的数据形式可以自定义,比如在给出的例子中,可以定义为{‘image’:image,‘landmarks’:landmark}这样的形式。
具体代码:

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.
    这是一个标记人脸特殊点的数据集,数据的存储方式为.csv
    数据的存储形式为:
    图片名称  标记点1_x  标记点1_y  标记点2_x  标记点2_y  ......
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)#定义数据集的读取方式
        self.root_dir = root_dir
        self.transform = transform #这里是对数据集的预处理,后面会讲到

    def __len__(self):
        return len(self.landmarks_frame)#定义len函数,返回数据集的大小

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()#将索引转化为链表的形式

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name) #读取图片
        landmarks = self.landmarks_frame.iloc[idx, 1:] #读取所有的标记点
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)#将每一个点的横纵坐标列为一行
        sample = {'image': image, 'landmarks': landmarks}#定义数据返回的形式

        if self.transform:
            sample = self.transform(sample)

        return sample

具体的使用如下:

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')#实例化之前定义的类

fig = plt.figure()

for i in range(len(face_dataset)):#运用len函数得到整个数据集的长度
    sample = face_dataset[i]#通过下标的形式调用getitem函数

    print(i, sample['image'].shape, sample['landmarks'].shape)#这里返回的是形式和之前定义的一样

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

结果是:
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

数据集图片(直接从教程中截图)

预处理(Transforms)

数据的预处理在之前打包数据集的时候也有提到,数据的预处理可以包括如图片的尺度变化,数据增强,将图片转换为torch的形式。
为了不每次都将数据预处理所需要的参数都传入到函数中,一般将数据的预处理写作一个可以调用的类的形式。运用__call__函数和__init__函数实现,其中__init__函数可以定义预处理的所需要的参数,__call__函数可以用来传入数据,调用处理函数。

tsfm = Transform(params) #实例化预处理方式
transformed_sample = tsfm(sample) #传入数据,得到预处理之后的transformed_sample

具体实现的预处理函数如下:

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值