使用方法
作为dataLoader的形参,不传入的时候使用默认的,可以自己定义。
DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
自己定义:
def collate_fn(examples):
"""
wfj:该函数表示对于batch_size中的每一个元素做以下一下的操作,通常用来进行数据的标准化工作
"""
print("==========================")
print(examples)
print(len(examples))
lengths = torch.tensor([len(ex[0]) for ex in examples])
inputs = [torch.tensor(ex[0]) for ex in examples]
targets = [torch.tensor(ex[1]) for ex in examples]
# 对batch内的样本进行padding,使其具有相同长度
inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])
targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])
#输出的几个参数的解释:解释变量;每个解释变量的长度;被解释变量;是否为填充位的标记。
return inputs, lengths, targets, inputs != vocab["<pad>"]
打印信息
我们的batch_size设置的是32。
解析
所以collate_fn接受的一个参数,就是Dataloader迭代取出的每个batch_size,我们可以在collate_fn中对每个batch_size的数据进行相关的操作和个性化的处理。