Pytorch: dataloader的一些使用心得

Pytorch: Dataloader的一些使用心得

这篇博文不讲原理,只讲一些使用方法和技巧。所有提供的信息仅供参考,不要当作金科玉律。

基本程序框架

首先给出讲述的时候使用的基本程序框架。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

class My_Dataset(Dataset):

    def __init__(self, list1, array2):
        self.len = len(list1)
        self.x_data = list1 # something support indexing, like a list, length = 16
        self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# padding unequal length sequences

def collate_fn(batch_data):
    return batch_data

# train dataloader & test dataloader

list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))

my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
                           batch_size = 4,
                           collate_fn = collate_fn)

从dataloader获取数据

注意这个函数:

def __getitem__(self, index):
    return self.x_data[index], self.y_data[index]

这代表,如果你用下标索引i从dataloader中取出值,返回值将会是一个长度为2的元组,下标为0的是list1[i](即第i+1个字母),下标为1的是array2[i](即一个size = (4, 5)的tensor)。暂且称这种形式的数据为data[i]

此时如果你运行如下指令:

for batch_data in enumerate(my_dataloader):
    # show batch_data

batch_data是一个长度为2的元组,下标为0的是这个batch的序号(在以上的程序里面是0~3),下标为1的是一个长度为4(batch_size)的support indexing的对象,这个对象的每个元素就是对应batch中应该包含的几个data[i],比如第0个batch的这个列表中的元素就分别是data[0],..data[3]。至于data[i]则是刚才说的由两项数据所构成的元组。
在这里,下标为1的对象是一个列表。而如果数据本身就是一个tensor的话,这里会给一个第一维维度为batch_size,其他维维度数对应的tensor.

此时如果你运行如下指令:

for batch_index, batch_data in enumerate(train_loader):
    # show data

这里的batch_index对应元组的下标为0的元素,即这个batch的序号(在以上的程序里面是0~3);batch_data对应上面的列表(support indexing的对象)。显然这种更细致的处理是更常用的。

对于以上讲的两点,读者可以直接跑一下附录1所示的程序来获得直观感受。

collate_fn的使用

在从dataloader中读取数据时,可以通过collate_fn做处理,使读取的数据符合要求。

让我们审视这个函数:

def collate_fn(batch_data):
    return batch_data

这里输入的batch_data就是上一节那个以batch_size为长度,以对应位置的data[i]为元素的列表。如果要取得元素之后进行特定处理,可以在这个函数里面操作;这个函数的返回值会代替原来那个列表的位置。可以运行附录2的代码获得直观感受。

collate_fn的使用实例

在自然语言处理中,可能要把不等长的tensor padding 成等长,这个步骤可以在collate_fn里面做。举个例子,下面的这个函数从不等长Tensor的列表生成一个padding成等长的高维tensor.

def collate_fn(data):
    # self.data: list of tensors of different length
    # data:[x[0], x[1], ..], x[0].shape = (20, 128), x[1].shape = (30, 128)
    #                        x[2].shape = (28, 128), x[3].shape = (25, 128)
    data.sort(key=lambda data: len(data[0]), reverse=True) # 按照序列长度降序排列
    seq_len_list = [elem.shape[0] for elem in data]
    data = pad_sequence(data, batch_first=True, padding_value=0)
    seq_len_list = torch.Tensor(seq_len_list)
    return data_batch, seq_len_list
# data_batch.shape = [4, 30, 128], seq_len_list = [20, 30, 28, 25]

函数的返回值包括合并的高维tensor和每个小tensor的实际长度,方便后续处理使用。

附录

附录1

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

torch.manual_seed(314)

class My_Dataset(Dataset):

    def __init__(self, list1, array2):
        self.len = len(list1)
        self.x_data = list1 # something support indexing, like a list, length = 16
        self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# padding unequal length sequences

def collate_fn(batch_data):
    return batch_data

# train dataloader & test dataloader

list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))

my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
                           batch_size = 4,
                           collate_fn = collate_fn)


for batch_data in enumerate(my_dataloader):
    # show batch_data
    print("New Batch")
    print(type(batch_data), len(batch_data), batch_data[0], type(batch_data[1]))
    print(len(batch_data[1]), type(batch_data[1][0]))
    print(batch_data[1][0][0], type(batch_data[1][0][1]), batch_data[1][0][1].shape)

for batch_index, batch_data in enumerate(my_dataloader):
    # show batch_data
    print("Batch", batch_index)
    for i in range(len(batch_data)):
        print(type(batch_data[i]), len(batch_data[i]))
        print(batch_data[i][0], type(batch_data[i][1]), batch_data[i][1].shape)

附录2

...

my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
                           batch_size = 4,
                           collate_fn = collate_fn)

for batch_index, batch_data in enumerate(my_dataloader):
    # show batch_data
    print("Batch", batch_index)
    print(batch_data)
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值