方法一
数据集中数据按照类别,进行文件放置。
import torch
import torchvision
import torchvision.transforms as transforms
# 读取数据集
datasets = torchvision.datasets.ImageFolder(root="数据集路径",
transform=transforms.Compose([
"数据预处理语句"比如
transforms.ToTensor()
]))
# 查看数据
print(datasets.imgs)
# 查看标签
print(datasets.classes)
# 批量加载数据集
dataloader = torch.utils.data.DataLoader(datasets,
shuffle = True,
batch_size="batch_size"
)
方法2
创建Dataset类来读取文件
import torch.utils.data as Dataset
class MyDataset(Dataset):
# 初始化过程
def __init__(self, .......):
# stuff
# 返回数据和标签
def __getitem__(self, index):
# stuff
return imgs, labels
# 返回所有数据的数目
def __len__(self):
return count
方法3
数据集类别文件存在嵌套
可采用os.listdir方式