pytorch使用教程-基于自定义 Dataloader中的collate_fn 函数 实现变长数据处理

问题背景

想要使用pytorch 框架中的 Dataset 和 Dataloader 类,将变长序列整合为batch数据 (主要是对长短不一的序列进行补齐),通过自定义collate_fn函数,实现对变长数据的处理。

主要思路

Dataset 主要负责读取单条数据,建立索引方式。
Dataloader 负责将数据聚合为batch。

应用实例

测试环境: python 3.6 ,pytorch 1.2.0

数据路径:
在这里插入图片描述
data路径下存储的是待存储的数据样本。
举例:其中的 1.json 样本格式为:
在这里插入图片描述

定义数据集class,进行数据索引

数据集class定义代码:

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
class time_series_dataset(Dataset):
    def __init__(self, data_root):
        """
        :param data_root:   数据集路径
        """
        self.data_root = data_root
        file_list = os.listdir(data_root)
        file_prefix = []
        for file in file_list:
            if '.json' in file:
                file_prefix.append(file.split('.')[0])
        file_prefix = list(set(file_prefix))
        self.data = file_prefix
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        prefix = self.data[index]
        import json
        with open(self.data_root+prefix+'.json','r',encoding='utf-8') as f:
            data_dic=json.load(f)
        feature = np.array(data_dic['feature'])
        length=len(data_dic['feature'])
        feature = torch.from_numpy(feature)
        label = np.array(data_dic['label'])
        label = torch.from_numpy(label)
        sample = {'feature': feature, 'label': label, 'id': prefix,'length':length}
        return sample

数据集实例化:

dataset = time_series_dataset("./data/") # "./data/" 为数据集文件存储路径

基于此数据集的实际数据格式如下:
举例: dataset[0]

{'feature': tensor([17, 14, 16, 18, 14, 16], dtype=torch.int32),
 'label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0], dtype=torch.int32),
 'id': '2',
 'length': 6}

定义collate_fn函数,传入Dataloader类

自定义collate_fn代码

def collate_func(batch_dic):
    from torch.nn.utils.rnn import pad_sequence
    batch_len=len(batch_dic)
    max_seq_length=max([dic['length'] for dic in batch_dic])
    mask_batch=torch.zeros((batch_len,max_seq_length))
    fea_batch=[]
    label_batch=[]
    id_batch=[]
    for i in range(len(batch_dic)):
        dic=batch_dic[i]
        fea_batch.append(dic['feature'])
        label_batch.append(dic['label'])
        id_batch.append(dic['id'])
        mask_batch[i,:dic['length']]=1
    res={}
    res['feature']=pad_sequence(fea_batch,batch_first=True)
    res['label']=pad_sequence(label_batch,batch_first=True)
    res['id']=id_batch
    res['mask']=mask_batch
    return res

说明: mask 字段用以存储变长序列的实际长度,补零的部分记为0,实际序列对应位置记为1。返回数据的格式及包含的字段,根据自己的需求进行定义。
Dataloader实例化调用代码:

train_loader = DataLoader(dataset, batch_size=3, num_workers=1, shuffle=True,collate_fn=collate_func)

完整流程代码

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
class time_series_dataset(Dataset):
    def __init__(self, data_root):
        """
        :param data_root:   数据集路径
        """
        self.data_root = data_root
        file_list = os.listdir(data_root)
        file_prefix = []
        for file in file_list:
            if '.json' in file:
                file_prefix.append(file.split('.')[0])
        file_prefix = list(set(file_prefix))
        self.data = file_prefix
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        prefix = self.data[index]
        import json
        with open(self.data_root+prefix+'.json','r',encoding='utf-8') as f:
            data_dic=json.load(f)
        feature = np.array(data_dic['feature'])
        length=len(data_dic['feature'])
        feature = torch.from_numpy(feature)
        label = np.array(data_dic['label'])
        label = torch.from_numpy(label)
        sample = {'feature': feature, 'label': label, 'id': prefix,'length':length}
        return sample
def collate_func(batch_dic):
    from torch.nn.utils.rnn import pad_sequence
    batch_len=len(batch_dic)
    max_seq_length=max([dic['length'] for dic in batch_dic])
    mask_batch=torch.zeros((batch_len,max_seq_length))
    fea_batch=[]
    label_batch=[]
    id_batch=[]
    for i in range(len(batch_dic)):
        dic=batch_dic[i]
        fea_batch.append(dic['feature'])
        label_batch.append(dic['label'])
        id_batch.append(dic['id'])
        mask_batch[i,:dic['length']]=1
    res={}
    res['feature']=pad_sequence(fea_batch,batch_first=True)
    res['label']=pad_sequence(label_batch,batch_first=True)
    res['id']=id_batch
    res['mask']=mask_batch
    return res
if __name__ == "__main__":
    dataset = time_series_dataset("./data/")
    batch_size=3
    train_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True,collate_fn=collate_func)
    for batch_idx, batch in tqdm(enumerate(train_loader),total=int(len(train_loader.dataset) / batch_size) + 1):
        inputs,labels,masks,ids=batch['feature'],batch['label'],batch['mask'],batch['id']
        break

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值