Pytorch使用collate_fn拼接维度不同的数据LSTM
DataLoader有一个参数collate_fn,这个参数接收自定义collate
函数,该函数在数据加载(即通过Dataloader取一个batch数据)之前,定义对每个batch数据的处理行为。
看下面的示例:
import torch
from torch.utils.data import Dataset, DataLoader,\
TensorDataset
def collate(data_):
"""
data_是一个列表,长度和DataLoader中定义的batch_size相等,
每一个列表元素为从Dataset采样一次得到的数据,
比如batch_size为2,从Dataset一次采样的数据为x,y,
那么data_表示为[(x1,y1),(x2,y2)]。而从DataLoader出来的
数据是 X=[x1,x2]^T和Y=[y1,y2]^T,
下面的代码就是将data_变成X和Y的形式。
"""
x, y = zip(*data_) # zip 可以将多个列表(或元组)的对应元素拼在一起,这样x1和x2就在一个列表里,y1和y2在一个列表里
x = torch.stack(x) # 把列表变成张量形式,stack默认在维度0拼接,维度大小等于batch_size大小
y = torch.stack(y)
return x, y
data = torch.rand(100,128) # 生成x数据
label = torch.randint(0,2, (100,)).float() # 生成y标签数据
dataset