dataset和dataloader的使用
1.dataset的使用
Pytorch提供多种数据集,要下载的话只需进入Pytorch官网,点击Docs下的torchvision,进入之后可以看到下方有多种常用数据集,如COCO、MNIST等,点击想要下载的数据集,进入后会有语句告知如何下载。
例如要下载CIFAR数据集,则通过
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
下载。第一个参数root为要下载到本地的地址;第二个参数如果为train=True则下载作为train数据集,如果为False则test数据集;中间参数可以不填;最后一个参数如果为download=True,则代表下载到本地。代码如下:
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print