Pytorch使用collate_fn拼接维度不同的数据LSTM

本文介绍了如何在Pytorch中使用DataLoader的collate_fn参数来处理LSTM模型训练时,数据维度不一致的问题。通过自定义collate_fn函数,实现了将不同序列长度的数据拼接成适合LSTM模型输入的形状,解决了LSTM模型要求固定序列长度输入的挑战。
摘要由CSDN通过智能技术生成

Pytorch使用collate_fn拼接维度不同的数据LSTM

DataLoader有一个参数collate_fn,这个参数接收自定义collate函数,该函数在数据加载(即通过Dataloader取一个batch数据)之前,定义对每个batch数据的处理行为。
看下面的示例:

import torch
from torch.utils.data import Dataset, DataLoader,\
TensorDataset

def collate(data_):
	"""
	data_是一个列表,长度和DataLoader中定义的batch_size相等,
	每一个列表元素为从Dataset采样一次得到的数据,
	比如batch_size为2,从Dataset一次采样的数据为x,y,
	那么data_表示为[(x1,y1),(x2,y2)]。而从DataLoader出来的
	数据是 X=[x1,x2]^T和Y=[y1,y2]^T,
	下面的代码就是将data_变成X和Y的形式。
	"""
	
	x, y = zip(*data_) # zip 可以将多个列表(或元组)的对应元素拼在一起,这样x1和x2就在一个列表里,y1和y2在一个列表里
	x = torch.stack(x) # 把列表变成张量形式,stack默认在维度0拼接,维度大小等于batch_size大小
	y = torch.stack(y)
	return x, y
	
data = torch.rand(100,128)  # 生成x数据
label = torch.randint(0,2, (100,)).float()  # 生成y标签数据
dataset 
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值