pytorch中常用的包
torchvision
torchvision包包含了目前流行的数据集,模型结构和常用的图片装换工具。
torchvision.datasets 包含数据集:
- MNIST
- COCO
- LSUN Classification
- ImageFolder
- ImageNet-12
- CIFAR10 and CIFAR100
- STL10
由于以上datasets都是torch.utils.data.Dataset的子类,可以通过torch.utils.data.DataLoader使用多线程。
例如:
torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
也可以将自己的数据集,调用DataLoader函数。具体步骤请参看multi_label_classifier.py及dataset_processing.py1
。
torchvision.transforms 对PIL.Image进行变换:
- torchvision.transforms.Compose(transforms): 将多个transforms组合起来
- torchvision.transforms.CenterCrop(size):将给定的PIL.Image进行中心切割,得到给定的size的图片。
- torchvision.transforms.RandomCrop(size):切割中心点的位置随机选取。
- torchvision.transforms.ToTensor():将取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,装换成shape为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor。
- torchvision.transforms.ToPILImage():将shape为[C,H,W]的Tensor或shape为(H,W,C)装换成shape为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor。