pytorch长度不同的数据如何放在一个batch

RNN及其变种算法处理一维信号经常会遇到信号长度不一致的问题。

from torch.utils.data import Dataloader
dataloader = Dataloader(dataset, batch_size=8)

这样是没法成功加载dataset,因为Dataloader要求一个batch内的数据shape是一致的,才能打包成一个方块投入模型。我们看一下源码里Dataloader初始化的方法

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,

其中collate_fn是pytorch为我们提供的数据裁剪函数,当collate_fn=None时,初始化会调用默认的裁剪方式即直接将数据打包,所以这时如果数据shape不一致,会打包不成功。因此我们需要自己写一个collate_fn函数,我常用的两种方式是:1.将所有数据截断到和最短的数据一样长;2.将所有的数据补零到和最长的数据一样长。
这里给出第一种方法的实现方式,第二种稍微复杂一点,也不难:

import torch
def collate_fn(data):  # 这里的data是一个list, list的元素是元组,元组构成为(self.data, self.label)
	# collate_fn的作用是把[(data, label),(data, label)...]转化成([data, data...],[label,label...])
	# 假设self.data的一个data的shape为(channels, length), 每一个channel的length相等,data[索引到数据index][索引到data或者label][索引到channel]
    data.sort(key=lambda x: len(x[0][0]), reverse=False)  # 按照数据长度升序排序
    data_list = []
    label_list = []
    min_len = len(data[0][0][0]) # 最短的数据长度 
    for batch in range(0, len(data)): #
        data_list.append(data[batch][0][:, :min_len])
        label_list.append(data[batch][1])
    data_tensor = torch.tensor(data_list, dtype=torch.float32)
    label_tensor = torch.tensor(label_list, dtype=torch.float32)
    data_copy = (data_tensor, label_tensor)
    return data_copy

使用方法也很简单

dataloader = torch.utils.data.Dataloader(dataset, collate_fn=collate_fn)
  • 9
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值