train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
此处的collate_fn,是一个函数,会将DataLoader生成的batch进行一次预处理
假设我们有一个Dataset,有input_ids、attention_mask等列:
使用torch创建dataloder时,如果使用默认的collate_fn(default_collate),输出的batch中,input_ids,和token_type_ids,attention_mask都是长度为 sequence_length 的列表(如果input_ids都已经pad到sequence_length),
列表的每个元素是大小为[batch_size]的tensor,
如果input_ids的长度不相等还会报错。
这并不是我们想要的模型的输入格式。
我们希望一个batch中,input_ids的应该是shape=(batch_size,max_seq_length)的tensor
默认情况下:
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
for batch in train_loader:
print(batch)
break
#如下所示,batch_size=3,所以input_ids都是长度为3的tensor组成的列表,这不是我们想要的
{'input_ids': List[torch.Tensor], 'attention_mask': List[torch.Tensor]}
Tensor的shape为(batch_size) 即(3)
这样是无法直接把这个batch输入bert的,必须把input_ids,和token_type_ids,attention_mask都转化为大小为 [ batch_size,sequence_length ] 的tensor才能输入bert。
所以,需要定义自己的collate_fn函数,对batch进行整理(使用torch.stack函数对列表进行拼接):
#定义collate_fn,把input_ids,和token_type_ids,attention_mask列转化为tensor
def collate_fn(examples):
batch = default_collate(examples)
batch['input_ids'] = torch.stack(batch['input_ids'], dim=1)
batch['token_type_ids'] = torch.stack(batch['token_type_ids'], dim=1)
batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=1)
return batch
train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)
#输出第一个batch
for batch in train_loader:
print(batch)
break
结果如下,可以直接输入bert模型
output=model(**batch)
batch
{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
tensor的shape=(batch_size,max_seq_length),符合bert的输入格式
然而,以上方法必须要求所有input_ids长度一致,否则依然会报错。
其实tokenizer类已经提供了更简单的方法 tokenizer.pad()
def collate_fn_SentenceClassify(features,tokenizer):
#batch = default_collate(examples) 不再需要这句话
#将batch中的input_ids和attention_mask等补齐
batch = tokenizer.pad(
features,
padding=True,
max_length=None,
)
return batch
此处,features是List[Dict[str,Tensor]], batch被整理成了Dict[str,Tensor],符合输入格式。
和DataCollatorWithPadding 效果一样
from transformers import DataCollatorWithPadding
对于更复杂的任务,例如文本生成任务,可以直接调用transformers库中的collate_fn,例如
from transformers import DataCollatorForSeq2Seq
具体用法可以看huggingface官网
更多DataCollator