获取一个batch数据的步骤
1,首先我们要确定数据集的长度n, 从dataset的__len__方法中可以得到这个值。
结果类似:n = 1000。
2,然后我们从0到n-1的范围中抽样出m个数(batch大小),由 DataLoader的 sampler和 batch_sampler参数进行抽样。
假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]
3,接着我们从数据集中去取这m个数对应下标的元素,由 Dataset的 __getitem__方法实现的
拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4,最后我们将结果整理成两个张量作为输出,由 DataLoader的 collate_fn参数指定的
拿到的结果是两个张量,类似batch = (features,labels),
其中 features = torch.stack([X[1],X[4],X[8],X[9]])
labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])
代码示例
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
def collate_fn(batch):
text, label = zip(*batch)
text = list(text)
label = list(label)
for i in range(len(text)):
text[i] = text[i] + '啧啧啧'
label[i] = label[i] + 2
return text, torch.tensor(label)
class MyDataset(Dataset):
def __init__(self, data):
self.data = pd.read_csv(data, encoding='gbk') # 进行初始化,这里是拿到所有的数据
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
text, label = self.data.iloc[idx, 0], self.data.iloc[idx, 1] # 这里是根据idx来返回数据,具体怎么返回根据自己的需求来。比如我的数据是读取的一个csv数据,我就需要用".iloc[x,y]"的索引方式来获取数据
return text, label
datasets = MyDataset('test-data.csv')
dataloader = DataLoader(datasets, batch_size=2, shuffle=True, collate_fn=collate_fn) # 初始化好datasets之后,就可以放到DataLoader中了,dataloader常用的参数有: dataset, batch_size, shuffle, num_workers, drop_last, collate_fn等
for i, (text, label) in enumerate(dataloader):
print(i)
print(text)
print(label)
代码解读
通过获取一个batch数据的讲解,加上代码中的注释基本就能搞明白是怎么组织数据和获取数据的了,这里额外讲解一下dataloader中的collate_fn参数的作用。
这个函数其实就是对获取到的一个batch的数据进行处理,比如将数字转换为tensor格式等,当然我们也可以在这个函数中对读取到的数据做额外的操作,比如下面展示了设置和不设置collate_fn参数时打印的内容:
# test-data.csv中的数据如下:
text,labels
你好,1
你知道,1
我是谁,0
我是你,1
困困困,0
进进进,1
哈哈哈,0
不说了,0
再也不,1
你就是,0
# 不设置collate_fn参数时打印的内容
0
('困困困', '我是你')
tensor([0, 1])
1
('你就是', '我是谁')
tensor([0, 0])
2
('进进进', '哈哈哈')
tensor([1, 0])
3
('不说了', '你好')
tensor([0, 1])
4
('你知道', '再也不')
tensor([1, 1])
# 设置collate_fn参数时打印的内容,可以看到每个text的内容都被加了"啧啧啧",每个label的数字都加了2。
# 此外原始的元组存储text的方式也变成了列表存储的的方式,这些都是在collate_fn函数中进行操作的
# 当然,这些操作也可以在__getitem__获取数据的时候操作
0
['你知道啧啧啧', '困困困啧啧啧']
tensor([3, 2])
1
['我是谁啧啧啧', '进进进啧啧啧']
tensor([2, 3])
2
['你好啧啧啧', '不说了啧啧啧']
tensor([3, 2])
3
['哈哈哈啧啧啧', '我是你啧啧啧']
tensor([2, 3])
4
['再也不啧啧啧', '你就是啧啧啧']
tensor([3, 2])