PyTorch DataLoader usage
from torch.utils.data.dataloader import DataLoader
Official documentation link
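- Use it together with a dataset and pass collate_fn to control the format of each batch.
- If you build your own list of samples (e.g. a list of dicts), you can pass the list directly to DataLoader to turn it into a dict of Tensors, but you must set collate_fn=default_data_collator. Otherwise PyTorch's default collation transposes list-valued fields (such as input_ids) into a list of per-position tensors instead of a single (batch_size, seq_len) Tensor.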
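from transformers import default_data_collator  # provided by Hugging Face transformers
DataLoader(d_samples, batch_size=batch_size, collate_fn=default_data_collator)  # d_samples: a list of dicts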
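A minimal sketch of format control with collate_fn, assuming only PyTorch is installed. ListDictDataset and dict_collate are hypothetical names used for illustration. The sketch also shows the problem mentioned above: without a collate_fn, the list-valued field input_ids comes back as a list of per-position tensors, while the custom collate_fn produces one (batch_size, seq_len) tensor.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ListDictDataset(Dataset):
    """Hypothetical dataset: each sample is a dict whose 'input_ids' is a plain Python list."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def dict_collate(batch):
    """Hypothetical collate_fn: turn a list of dicts into a dict of batched tensors."""
    out = {}
    for key in batch[0]:
        # torch.tensor on equal-length lists gives shape (batch_size, seq_len)
        out[key] = torch.tensor([sample[key] for sample in batch])
    return out


samples = [
    {"input_ids": [1, 2, 3], "label": 0},
    {"input_ids": [4, 5, 6], "label": 1},
]
dataset = ListDictDataset(samples)

# Default collation: 'input_ids' is transposed into a list of 3 tensors, one per position
default_batch = next(iter(DataLoader(dataset, batch_size=2)))

# Custom collation: 'input_ids' becomes a single tensor of shape (2, 3)
custom_batch = next(iter(DataLoader(dataset, batch_size=2, collate_fn=dict_collate)))
```

The custom function assumes every sample has the same keys and equal-length lists; padding variable-length sequences would need extra logic.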
Source (from Hugging Face transformers, transformers/data/data_collator.py):
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """
    Very simple data collator that simply collates batches of dict-like objects and performs special handling for
    potential keys named:

        - ``label``: handles a single value (int or float) per object
        - ``label_ids``: handles a list of values per object

    Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
    to the model. See glue and ner for example of how it's useful.
    """
    # ... (function body not included in this excerpt)
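The docstring only outlines the behavior. Below is a simplified re-implementation of what it describes, written with plain PyTorch so it runs without transformers; it is not the library's actual code. simple_data_collator is a hypothetical name, and collecting both label and label_ids under a single labels key is an assumption of this sketch, so check the installed transformers source before relying on the exact output keys.

```python
from typing import Any, Dict, List

import torch
from torch.utils.data import DataLoader


def simple_data_collator(features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Simplified sketch of the behavior described in default_data_collator's docstring.

    Assumption: 'label' and 'label_ids' are both collected under a single 'labels' key.
    """
    first = features[0]
    batch: Dict[str, torch.Tensor] = {}

    if "label" in first and first["label"] is not None:
        # One scalar per sample: ints become a long tensor, floats a float tensor
        dtype = torch.long if isinstance(first["label"], int) else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        # A list of values per sample (e.g. one tag per token for NER)
        dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
        batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Every other key keeps its name and is stacked along a new batch dimension
    for key, value in first.items():
        if key in ("label", "label_ids") or value is None or isinstance(value, str):
            continue
        if isinstance(value, torch.Tensor):
            batch[key] = torch.stack([f[key] for f in features])
        else:
            batch[key] = torch.tensor([f[key] for f in features])
    return batch


# A plain list of dicts can be passed to DataLoader directly, as described above
features = [
    {"input_ids": [101, 7, 102], "attention_mask": [1, 1, 1], "label": 1},
    {"input_ids": [101, 9, 102], "attention_mask": [1, 1, 1], "label": 0},
]
loader = DataLoader(features, batch_size=2, collate_fn=simple_data_collator)
batch = next(iter(loader))

# Token-level labels via 'label_ids'
token_features = [
    {"input_ids": [1, 2], "label_ids": [0, 1]},
    {"input_ids": [3, 4], "label_ids": [1, 0]},
]
token_batch = simple_data_collator(token_features)
```

Because property names pass through unchanged, the resulting dict can be unpacked straight into a model call, which is the point the docstring makes about "property names of the input object".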