PyTorch官网PyTorch
Docs→torchvision→数据集
1、使用标准数据集 Dataset
import torchvision
train_set = torchvision.datasets.CIFAR10(
root = "./Dataset", # 数据集存储的位置
train = True, # True--训练集 False--测试集
transform: Optional[Callable] = None, # transform操作
target_transform: Optional[Callable] = None, # 对目标transform操作
download = True # true---下载 已经下载就不会下载了
)
# 查看数据集
print(test_set[0]) # PIL类型
# (<PIL> , 3) img, target
2、Dataloader
从dataset中取数据来加载到网络中去使用。
torch.utils.data.DataLoader(
dataset, # 前面介绍的dataset实例化
batch_size=1, # 每个batch的大小
shuffle=False, # True--打乱数据 默认False
sampler=None, #
batch_sampler=None,
num_workers=0, # 采用进程数量 默认0(一个主进程)
collate_fn=None, #
pin_memory=False,
drop_last=False, # 总数量/bantch size 除不尽时 余数是舍弃True还是不舍弃False
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor=2,
persistent_workers=False
)
实例,使用测试数据集
import torchvision
# 准备的测试集
from torch.utils.data import DataLoader
test_data = torchvision.datasets.CIFAR10("./Dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
- batch size=4
每次取4个数据打包,比较一个数据与每个batch中数据的区别
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
# 一个patch中的数据
for data in test_loader:
imgs, targets = data
print(imgs.shape)
print(targets)
- drop_last=False(上面为false,下面为true)
- shuffle=False
每个epoch处理的图片顺序相同,当设置为true就不一样了。