Pytorch使用collate_fn拼接维度不同的数据LSTM

Pytorch使用collate_fn拼接维度不同的数据LSTM

DataLoader有一个参数collate_fn,这个参数接收自定义collate函数,该函数在数据加载(即通过Dataloader取一个batch数据)之前,定义对每个batch数据的处理行为。
看下面的示例:

import torch
from torch.utils.data import Dataset, DataLoader,\
TensorDataset

def collate(data_):
	"""
	data_是一个列表,长度和DataLoader中定义的batch_size相等,
	每一个列表元素为从Dataset采样一次得到的数据,
	比如batch_size为2,从Dataset一次采样的数据为x,y,
	那么data_表示为[(x1,y1),(x2,y2)]。而从DataLoader出来的
	数据是 X=[x1,x2]^T和Y=[y1,y2]^T,
	下面的代码就是将data_变成X和Y的形式。
	"""
	
	x, y = zip(*data_) # zip 可以将多个列表(或元组)的对应元素拼在一起,这样x1和x2就在一个列表里,y1和y2在一个列表里
	x = torch.stack(x) # 把列表变成张量形式,stack默认在维度0拼接,维度大小等于batch_size大小
	y = torch.stack(y)
	return x, y
	
data = torch.rand(100,128)  # 生成x数据
label = torch.randint(0,2, (100,)).float()  # 生成y标签数据
dataset = TensorDataset(data,label)  # 构建数据集
loader = DataLoader(dataset, batch_size=32, collate_fn=collate)  # 构建加载器,collate_fn就是在这一步进行处理
X,Y = next(iter(loader))  # 得到一个批次的数据
print(X.shape, Y.shape)  # 和预期一致  
# Output: (32,128),(32)

除了上面的简单情形,在写训练代码的时候,碰到了下面的情况:
当时想使用 LSTM模型进行深度学习训练,输入数据源为时空数据,提取之后的形状为(node_num, seq_len, n_feat), node_num表示节点数,seq_len表示序列长度,n_feat为节点特征维度。

如果直接进行数据加载(DataLoader不通过collate_fn处理),输出的shape将是(batch_size, node_num, seq_len, n_feat)

而LSTM模型需要的输入形状是(seq_len, batch_size, n_feat), 显然不匹配。刚开始想到通过以下方式解决:

# [batch_size, node_num, seq_len, n_feat] 
# -> [batch_size*node_num, seq_len, n_feat]
# -> [seq_len, batch_size*node_num, n_feat]
x.reshape(-1,seq_len,n_feat).permute(1,0,2) 

但这样其实不行,因为每次采样的数据 node_num 不相等, 想想之前的第一段代码,如果x1 的shape是[5,12,100], x2[7,12,100], 通过torch.stack 无法拼接, DataLoader默认的行为和上面代码是一样的。如果强行那么做会报错,类似于这种:
invalid argument 0: Sizes of tensors must match except in dimension 0. Got 7 and 5 in dimension 1

在这种情况下,需要通过collate_fn自定义处理行为,看下面的代码:

import torch
from torch.utils.data import Dataset, DataLoader,\
TensorDataset
import numpy as np

class TestDataset(Dataset):
    def __init__(self, n=100):
        super(TestDataset,self).__init__()
        # 定义node_num不同的两组数据
        self.data1 = torch.rand(n,5,12,128) 
        self.data2 = torch.rand(n,7,12,128)
        
    def __len__(self):
        return len(self.data1)
    
    def __getitem__(self,indx):
    	# 随机选择一组数据,得到的x可能出现node_num不同的情况
        if torch.rand(1) > 0.5:
            return self.data2[indx],torch.tensor(1)
        else:
            return self.data1[indx],torch.tensor(0)

def collate(data):
    x, y = zip(*data)
    # 统计每个x的node_num
    sz = [len(x) for x in x]
    # 计算偏移量
    cum_offsets = [0] + np.cumsum(sz).tolist()
    start_end = [[start, end] for start, end in zip(cum_offsets, cum_offsets[1:])]
    # 在node_num维度拼接,变成 [seq_len, node_num_total, n_feat]
    x = torch.cat(x, dim=0).permute(1,0,2)
    y = torch.stack(y)
    # 为了能从x恢复原来的数据,需要返回偏移量
    start_end = torch.LongTensor(start_end)
    return x, start_end, y

tdataset = TestDataset()
loader = DataLoader(tdataset, batch_size=32, collate_fn=collate)
X, offsets, Y = next(iter(loader))
print(X.shape, offsets.shape, Y.shape)  # 和预期一致
# Output: (12,164,128), (32,2), (32)

通过以上代码的处理,直接就得到了目标shape seq_len, batch_size, n_feat,不过这个batch_size等于batch_size个node_num 相加,在node_num相等的情况下,这种方式和reshape+permute的方式是等效的。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值