RNN及其变种算法处理一维信号经常会遇到信号长度不一致的问题。
from torch.utils.data import Dataloader
dataloader = Dataloader(dataset, batch_size=8)
这样是没法成功加载dataset,因为Dataloader要求一个batch内的数据shape是一致的,才能打包成一个方块投入模型。我们看一下源码里Dataloader初始化的方法
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
其中collate_fn是pytorch为我们提供的数据裁剪函数,当collate_fn=None时,初始化会调用默认的裁剪方式即直接将数据打包,所以这时如果数据shape不一致,会打包不成功。因此我们需要自己写一个collate_fn函数,我常用的两种方式是:1.将所有数据截断到和最短的数据一样长;2.将所有的数据补零到和最长的数据一样长。
这里给出第一种方法的实现方式,第二种稍微复杂一点,也不难:
import torch
def collate_fn(data): # 这里的data是一个list, list的元素是元组,元组构成为(self.data, self.label)
# collate_fn的作用是把[(data, label),(data, label)...]转化成([data, data...],[label,label...])
# 假设self.data的一个data的shape为(channels, length), 每一个channel的length相等,data[索引到数据index][索引到data或者label][索引到channel]
data.sort(key=lambda x: len(x[0][0]), reverse=False) # 按照数据长度升序排序
data_list = []
label_list = []
min_len = len(data[0][0][0]) # 最短的数据长度
for batch in range(0, len(data)): #
data_list.append(data[batch][0][:, :min_len])
label_list.append(data[batch][1])
data_tensor = torch.tensor(data_list, dtype=torch.float32)
label_tensor = torch.tensor(label_list, dtype=torch.float32)
data_copy = (data_tensor, label_tensor)
return data_copy
使用方法也很简单
dataloader = torch.utils.data.Dataloader(dataset, collate_fn=collate_fn)