Problem Background
We want to use the Dataset and DataLoader classes from the PyTorch framework to assemble variable-length sequences into batches (mainly by padding sequences of unequal length to a common length), handling the variable-length data through a custom collate_fn function.
Main Idea
Dataset is responsible for reading individual samples and establishing the indexing scheme.
DataLoader is responsible for aggregating samples into batches, padding as needed via the custom collate_fn (see the sketch below).
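Why a custom collate_fn is necessary: the DataLoader's default collation essentially calls torch.stack over the samples in a batch, which requires equal tensor shapes and therefore fails on variable-length sequences. A minimal sketch with made-up toy tensors:

import torch

# Two toy variable-length sequences, stand-ins for real samples.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

try:
    torch.stack([a, b])  # what the default collation effectively does
except RuntimeError as e:
    print(e)  # stack expects each tensor to be equal size ...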
Worked Example
Test environment: Python 3.6, PyTorch 1.2.0
Data layout:
The data directory stores the individual data samples, one JSON file per sample.
For example, the format of a sample such as 1.json is:
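The original sample contents are not reproduced here; judging from the dataset code and the dataset[0] output shown further below, a sample plausibly holds two integer lists under the keys feature and label, along these (illustrative, hypothetical) lines:

{
    "feature": [17, 14, 16, 18, 14, 16],
    "label": [0, 0, 0, 1, 0, 0]
}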
Defining the Dataset Class for Indexing the Data
Dataset class definition:
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset

class time_series_dataset(Dataset):
    def __init__(self, data_root):
        """
        :param data_root: path to the dataset directory
        """
        self.data_root = data_root
        file_list = os.listdir(data_root)
        # Collect the unique file-name prefixes of the .json samples.
        file_prefix = []
        for file in file_list:
            if file.endswith('.json'):
                file_prefix.append(file.split('.')[0])
        self.data = sorted(set(file_prefix))  # deduplicate; sort for a deterministic order

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        prefix = self.data[index]
        with open(os.path.join(self.data_root, prefix + '.json'), 'r', encoding='utf-8') as f:
            data_dic = json.load(f)
        feature = torch.from_numpy(np.array(data_dic['feature']))
        label = torch.from_numpy(np.array(data_dic['label']))
        length = len(data_dic['feature'])  # true (unpadded) sequence length
        return {'feature': feature, 'label': label, 'id': prefix, 'length': length}
Instantiating the dataset:
dataset = time_series_dataset("./data/")  # "./data/" is the dataset directory
A real sample from this dataset then looks as follows.
For example, dataset[0]:
{'feature': tensor([17, 14, 16, 18, 14, 16], dtype=torch.int32),
'label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0], dtype=torch.int32),
'id': '2',
'length': 6}
Defining the collate_fn Function and Passing It to the DataLoader
Custom collate_fn:
from torch.nn.utils.rnn import pad_sequence

def collate_func(batch_dic):
    batch_len = len(batch_dic)                                # batch size
    max_seq_length = max(dic['length'] for dic in batch_dic)  # longest sequence in this batch
    # Mask: 1 marks a real position, 0 marks padding.
    mask_batch = torch.zeros((batch_len, max_seq_length))
    fea_batch = []
    label_batch = []
    id_batch = []
    for i, dic in enumerate(batch_dic):
        fea_batch.append(dic['feature'])
        label_batch.append(dic['label'])
        id_batch.append(dic['id'])
        mask_batch[i, :dic['length']] = 1
    res = {
        'feature': pad_sequence(fea_batch, batch_first=True),  # (batch, max_seq_length)
        'label': pad_sequence(label_batch, batch_first=True),
        'id': id_batch,
        'mask': mask_batch,
    }
    return res
Note: the mask field records the actual extent of each variable-length sequence: zero-padded positions are marked 0, and positions holding the real sequence are marked 1. The fields and format of the returned batch can be defined to suit your own needs.
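As one illustration of how the mask can be consumed downstream (a hedged sketch, not part of the original pipeline: the helper masked_bce_loss is hypothetical, and it assumes per-timestep binary labels with logits, labels, and mask all of shape (batch, max_seq_length)):

import torch.nn.functional as F

def masked_bce_loss(logits, labels, mask):
    # Hypothetical helper: per-position loss, no reduction yet.
    per_pos = F.binary_cross_entropy_with_logits(logits, labels.float(), reduction='none')
    per_pos = per_pos * mask            # padded positions contribute zero
    return per_pos.sum() / mask.sum()   # average over real positions only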
Instantiating and calling the DataLoader:
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=3, num_workers=1, shuffle=True, collate_fn=collate_func)
Complete Pipeline Code
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

class time_series_dataset(Dataset):
    def __init__(self, data_root):
        """
        :param data_root: path to the dataset directory
        """
        self.data_root = data_root
        file_list = os.listdir(data_root)
        # Collect the unique file-name prefixes of the .json samples.
        file_prefix = []
        for file in file_list:
            if file.endswith('.json'):
                file_prefix.append(file.split('.')[0])
        self.data = sorted(set(file_prefix))  # deduplicate; sort for a deterministic order

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        prefix = self.data[index]
        with open(os.path.join(self.data_root, prefix + '.json'), 'r', encoding='utf-8') as f:
            data_dic = json.load(f)
        feature = torch.from_numpy(np.array(data_dic['feature']))
        label = torch.from_numpy(np.array(data_dic['label']))
        length = len(data_dic['feature'])  # true (unpadded) sequence length
        return {'feature': feature, 'label': label, 'id': prefix, 'length': length}

def collate_func(batch_dic):
    batch_len = len(batch_dic)                                # batch size
    max_seq_length = max(dic['length'] for dic in batch_dic)  # longest sequence in this batch
    mask_batch = torch.zeros((batch_len, max_seq_length))     # 1 = real position, 0 = padding
    fea_batch = []
    label_batch = []
    id_batch = []
    for i, dic in enumerate(batch_dic):
        fea_batch.append(dic['feature'])
        label_batch.append(dic['label'])
        id_batch.append(dic['id'])
        mask_batch[i, :dic['length']] = 1
    res = {
        'feature': pad_sequence(fea_batch, batch_first=True),
        'label': pad_sequence(label_batch, batch_first=True),
        'id': id_batch,
        'mask': mask_batch,
    }
    return res

if __name__ == "__main__":
    dataset = time_series_dataset("./data/")
    batch_size = 3
    train_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4,
                              shuffle=True, collate_fn=collate_func)
    # len(train_loader) is the number of batches, including a final partial batch.
    for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
        inputs, labels, masks, ids = batch['feature'], batch['label'], batch['mask'], batch['id']
        break
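To sanity-check the pipeline, it can help to inspect the shapes of a single batch; a quick illustrative snippet (the exact max lengths depend on which samples land in the batch):

batch = next(iter(train_loader))
print(batch['feature'].shape)  # torch.Size([3, max feature length in the batch])
print(batch['label'].shape)    # torch.Size([3, max label length in the batch])
print(batch['mask'].shape)     # matches batch['feature']
print(batch['id'])             # list of 3 file-name prefixes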