transformers DataCollatorForPermutationLanguageModeling类

构造方法
DataCollatorForPermutationLanguageModeling( 
	tokenizer: PreTrainedTokenizerBase,
	plm_probability: float = 0.16666666666666666,
	max_span_length: int = 5,
	return_tensors: str = 'pt')

用于特指XLNet进行预训练PLM(排列组合语言模型)时使用的数据收集器,关于XLNet以及PLM,可以参考这里

参数tokenzier表示编码句子的分词器。

参数plm_probability表示部分预测的序列末尾被mask的概率。序列末尾被mask的长度=span_length*plm_probability。

参数max_span_length表示span_length的采样区间的长度[1, max_span_length]。

参数return_tensors表示返回数据的类型,有三个可选项,分别是"tf"、“pt”、“np”,分别表示tensorflow可以处理的数据类型,pytorch可以处理的数据类型以及numpy数据类型。

使用示例
def preprocess_fn(data):
    data = {k: sum(data[k], []) for k in data.keys()}
    length = len(data["input_ids"]) // 128 * 128
    result = {k: [v[i: i + 128] for i in range(0, length, 128)] for k, v in data.items()}
    result["labels"] = result["input_ids"].copy()
    return result


dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = transformers.AutoTokenizer.from_pretrained("xlnet-base-cased")
data_collator = transformers.DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer,
                                                             			return_tensors="tf")
dataset = dataset.map(lambda data: tokenizer(data["text"], truncation=True),
                      batched=True,
                      batch_size=1000,
                      remove_columns=["text"])
dataset = dataset.map(preprocess_fn,
                      batched=True,
                      batch_size=1000)
train_dataset = dataset["train"].to_tf_dataset(columns=["input_ids", "attention_mask", "labels"],
                                               batch_size=16,
                                               shuffle=True,
                                               collate_fn=data_collator)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

不负韶华ღ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值