collate_fn函数可以在取出数值的时候对每次进行前面补充0的操作,很方便每一个批次对于数据进行操作
首先定义对应的MyDataset类
import torch
from torch.utils.data import Dataset, DataLoader
# 1. 用Dataset封装数据集,仅做示范,实际可直接用TensorDataset封装
class MyDataset(Dataset):
def __init__(self, x, y):
#assert x.size(0)==y.size(0)
self.x, self.y = x, y
#定义初始化变量
def __getitem__(self, idx):
#print('__getitem__')
return (self.x[idx], self.y[idx])
#定义每次取出的对应数值
def __len__(self):
return len(self.x)
#定义tensor的总长度
# 2. 用DataLoader定义数据批量迭代器
接下来定义collate_fn函数为每次取出的数值的操作,这里之所以collate_fn函数最终决定而不是由__getitem__函数决定,是因为__getitem__函数在collate_fn函数之前运行,
这里面的__getitem__函数可以视为每一次取出数据的时候对数据进行的操作,而collate_fn函数是针对于每一个批次取出数据进行的操作
放入对应的参数,这里使用collate_fn函数让每一个批次以这个批次数值的最大长度作为分界,这个批次如果达不到最大长度的需要在前面的位置补充零。
test1 = [[3],[1,2],[5,6,7],[9]]
test2 = [[6],[4,5],[7,8,9],[10]]
def collate_fn(data):
#这里的collate_fn是对相应的数据进行处理
#找出最大长度的数组,如果其他数组达不到相应的长度,
#就在句子的前面位置补充0
print('function collate_fn')
def _pad_sequences(seqs):
seqs_0 = []
seqs_1 = []
for seq in seqs:
seqs_0.append(seq[0])
seqs_1.append(seq[1])
lens0 = [len(seq) for seq in seqs_0]
lens1 = [len(seq) for seq in seqs_1]
max_len0 = max(lens0)
max_len1 = max(lens1)
max_len = max(max_len0,max_len1)
padded_seqs0 = torch.zeros(len(seqs_0),max_len).long()
padded_seqs1 = torch.zeros(len(seqs_1),max_len).long()
for i in range(len(seqs_0)):
seq = seqs_0[i]
start = max_len-lens0[i]
padded_seqs0[i,start:] = torch.Tensor(seq)
for i in range(len(seqs_1)):
seq = seqs_1[i]
start = max_len-lens1[i]
padded_seqs1[i,start:] = torch.Tensor(seq)
return padded_seqs0,padded_seqs1
data1,data2 = _pad_sequences(data)
#print('data = ')
#print(data)
return data1,data2
dataset = MyDataset(test1,test2)
MyDataLoader = DataLoader(dataset=dataset,shuffle=True,batch_size=4,collate_fn=collate_fn)
for data_iter1,data_iter2 in MyDataLoader:
print('data_iter1 = ')
print(data_iter1)
print('data_iter2 = ')
print(data_iter2)
#print('data_iter2 = ')
#print(data_iter2)
最终MyDataLoader循环之中每次取出来的值由collate_fn的返回值return data1,data2决定