0.前言
框架：PyTorch（数据集与图像变换使用 torchvision）
1. 直接从网上下载数据集（以 FashionMNIST 为例），用于模型的测试
# Download FashionMNIST and wrap both splits in DataLoaders.
# Requires:
#   from torchvision import datasets, transforms
#   from torch.utils.data import DataLoader
tsf = transforms.Compose([transforms.ToTensor()])  # PIL image -> tensor

# torchvision's FashionMNIST downloads the raw files for BOTH splits into
# <root>/FashionMNIST/raw regardless of `train`, so both splits share one
# root. Separate roots per split would download and store everything twice.
FASHION_MNIST_ROOT = './data/fashionmnist'

train_data = datasets.FashionMNIST(root=FASHION_MNIST_ROOT, train=True,
                                   transform=tsf, download=True)
test_data = datasets.FashionMNIST(root=FASHION_MNIST_ROOT, train=False,
                                  transform=tsf, download=True)

# Shuffle only the training data; keep the test order fixed for evaluation.
train_iter = DataLoader(train_data, batch_size=32, shuffle=True)
test_iter = DataLoader(test_data, batch_size=32, shuffle=False)
2. 从本地文件夹中加载数据集（使用 ImageFolder）
# Load a labelled image dataset from a local directory tree.
# Requires:
#   from torchvision import datasets, transforms
#   from torch.utils.data import DataLoader
# Resize every image to 28x28 before converting it to a tensor.
tsf = transforms.Compose([transforms.Resize((28, 28)),
                          transforms.ToTensor()])

# ImageFolder uses each sub-directory of `root` as one class.
# Forward slashes work on Windows and POSIX alike; a backslash-separated
# path only resolves on Windows (elsewhere the backslashes are part of
# a single file name).
train_data = datasets.ImageFolder(root='image/dogcat/dogcat/train',
                                  transform=tsf)
train_iter = DataLoader(train_data, batch_size=32, shuffle=True)
本地文件夹的目录结构（ImageFolder 将 root 下的每个子文件夹视为一个类别，子文件夹名即类别名，例如 train/cat/ 与 train/dog/）：