构造方法
DataCollatorForPermutationLanguageModeling(
tokenizer: PreTrainedTokenizerBase,
plm_probability: float = 0.16666666666666666,
max_span_length: int = 5,
return_tensors: str = 'pt')
用于特指XLNet进行预训练PLM(排列组合语言模型)时使用的数据收集器,关于XLNet以及PLM,可以参考这里。
参数tokenzier表示编码句子的分词器。
参数plm_probability表示部分预测的序列末尾被mask的概率。序列末尾被mask的长度=span_length*plm_probability。
参数max_span_length表示span_length的采样区间的长度[1, max_span_length]。
参数return_tensors表示返回数据的类型,有三个可选项,分别是"tf"、“pt”、“np”,分别表示tensorflow可以处理的数据类型,pytorch可以处理的数据类型以及numpy数据类型。
使用示例
def preprocess_fn(data):
data = {k: sum(data[k], []) for k in data.keys()}
length = len(data["input_ids"]) // 128 * 128
result = {k: [v[i: i + 128] for i in range(0, length, 128)] for k, v in data.items()}
result["labels"] = result["input_ids"].copy()
return result
dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = transformers.AutoTokenizer.from_pretrained("xlnet-base-cased")
data_collator = transformers.DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer,
return_tensors="tf")
dataset = dataset.map(lambda data: tokenizer(data["text"], truncation=True),
batched=True,
batch_size=1000,
remove_columns=["text"])
dataset = dataset.map(preprocess_fn,
batched=True,
batch_size=1000)
train_dataset = dataset["train"].to_tf_dataset(columns=["input_ids", "attention_mask", "labels"],
batch_size=16,
shuffle=True,
collate_fn=data_collator)