batchify_fn = lambda samples, fn=Dict({
'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'), # input
'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'), # segment
'seq_len': Stack(dtype='int64'),
}): fn(samples)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'), # input_ids
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'), # token_type_ids
Stack(dtype='int64'), # seq_len
): fn(samples)
记住这个函数,返回的都是tuple,只是针对输入有tuple类型的输入或者dict类型的输入,相关的函数做了封装