import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
def main(root: str = "cifar", batchsz: int = 32) -> None:
    """Download CIFAR-10 and print the shapes of one train and one test batch.

    Builds ``DataLoader``s over the train and test splits of
    ``torchvision.datasets.CIFAR10`` (``download=True``, so the files are
    fetched into *root* when not already present), then pulls the first
    batch from each loader and prints the tensor shapes.

    Args:
        root: Directory the dataset is downloaded to / read from.
        batchsz: Number of samples per batch for both loaders.
    """
    # Both splits share the same preprocessing. ToTensor must be
    # *instantiated*: Compose calls each element with the image, and calling
    # the ToTensor class itself with an argument raises TypeError.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    # Keep datasets and loaders under distinct names; reusing one name for
    # both is what previously caused the test loader to wrap the train loader.
    train_set = datasets.CIFAR10(root, train=True, transform=transform, download=True)
    test_set = datasets.CIFAR10(root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_set, batch_size=batchsz, shuffle=True)
    # Evaluation data does not need shuffling.
    test_loader = DataLoader(test_set, batch_size=batchsz, shuffle=False)

    # Python 3 iterators have no .next() method; use the builtin next().
    x, label = next(iter(train_loader))
    print('x:', x.shape, 'label:', label.shape)

    x_test, label_test = next(iter(test_loader))
    print('test x:', x_test.shape, 'test label:', label_test.shape)
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
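# pytorch下载CIFAR10数据集 (Downloading the CIFAR-10 dataset with PyTorch)
# 最新推荐文章于 2024-07-25 11:48:56 发布 (Latest recommended article published 2024-07-25 11:48:56)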