torch中datasets.load_dataset用法

转发:https://blog.csdn.net/weixin_49346755/article/details/125284869

函数原型
datasets.load_dataset(
	path: str,
    name: Optional[str] = None,
    data_dir: Optional[str] = None,
    data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
    split: Optional[Union[str, Split]] = None,
    cache_dir: Optional[str] = None,
    features: Optional[Features] = None,
    download_config: Optional[DownloadConfig] = None,
    download_mode: Optional[DownloadMode] = None,
    ignore_verifications: bool = False,
    keep_in_memory: Optional[bool] = None,
    save_infos: bool = False,
    revision: Optional[Union[str, Version]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    task: Optional[Union[str, TaskTemplate]] = None,
    streaming: bool = False,
    **config_kwargs
    )
函数说明

load_dataset函数从Hugging Face Hub或者本地数据集文件中加载一个数据集。可以通过 https://huggingface.co/datasets 或者datasets.list_datasets()函数来获取所有可用的数据集。

参数path表示数据集的名字或者路径。可以是一个数据集的名字,比如"imdb"、“glue”;也可以是通用的产生数据集文件的脚本,比如"json"、“csv”、“parquet”、“text”;或者是在数据集目录中的脚本(.py)文件,比如“glue/glue.py”。

参数name表示数据集中的子数据集,当一个数据集包含多个数据集时,就需要这个参数。比如"glue"数据集下就包含"sst2"、“cola”、"qqp"等多个子数据集,此时就需要指定name来表示加载哪一个子数据集。

参数data_dir表示数据集所在的目录,参数data_files表示本地数据集文件。

参数split如果为None,则返回一个DataDict对象,包含多个DataSet数据集对象;如果给定的话,则返回单个DataSet对象。

参数cache_dir表示缓存数据的目录,默认为"~/.cache/huggingface/datasets"。参数keep_in_memory表示是否将数据集缓存在内存中,加载一次后,再次加载可以提高加载速度。

参数revision表示加载数据集的脚本的版本。

函数使用

1、加载imdb数据集

>>> dataset = datasets.load_dataset("imdb")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

2、加载glue下的cola子数据集

>>> dataset = datasets.load_dataset("glue", name="cola")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})

3、通过csv脚本加载本地的test.tsv文件中的数据集

>>> dataset = datasets.load_dataset("csv", data_dir="E:\Python\\transfomers\\test", data_files="test.tsv")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['14'],
        num_rows: 4
    })
})

4、通过glue.py脚本文件加载cola数据集

>>> dataset_1 = datasets.load_dataset("../dataset/glue/glue.py", name="cola")
  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
class Dn_datasets(Dataset): def __init__(self, data_root, data_dict, transform, load_all=False, to_gray=False, s_factor=1, repeat_crop=1): self.data_root = data_root self.transform = transform self.load_all = load_all self.to_gray = to_gray self.repeat_crop = repeat_crop if self.load_all is False: self.data_dict = data_dict else: self.data_dict = [] for sample_info in data_dict: sample_data = Image.open('/'.join((self.data_root, sample_info['path']))).copy() if sample_data.mode in ['RGBA']: sample_data = sample_data.convert('RGB') width = sample_info['width'] height = sample_info['height'] sample = { 'data': sample_data, 'width': width, 'height': height } self.data_dict.append(sample) def __len__(self): return len(self.data_dict) def __getitem__(self, idx): sample_info = self.data_dict[idx] if self.load_all is False: sample_data = Image.open('/'.join((self.data_root, sample_info['path']))) if sample_data.mode in ['RGBA']: sample_data = sample_data.convert('RGB') else: sample_data = sample_info['data'] if self.to_gray: sample_data = sample_data.convert('L') # crop (w_start, h_start, w_end, h_end) image = sample_data target = sample_data sample = {'image': image, 'target': target} if self.repeat_crop != 1: image_stacks = [] target_stacks = [] for i in range(self.repeat_crop): sample_patch = self.transform(sample) image_stacks.append(sample_patch['image']) target_stacks.append(sample_patch['target']) return torch.stack(image_stacks), torch.stack(target_stacks) else: sample = self.transform(sample) return sample['image'], sample['target']
06-01
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值