1 torchvision的模块
1、torchvision.datasets
如:COCO 目标检测、语义分割;MNIST 手写文字;CIFAR 物体识别
2、torchvision.io
输入输出模块,不常用
3、torchvision.models
提供一些比较常见的神经网络,有的已经预训练好,如分类、语义分割、目标检测、视频分类
4、torchvision.ops
torchvision提供的一些比较少见的特殊的操作,不常用
5、torchvision.transforms
6、torchvision.utils
提供一些常用的小工具,如TensorBoard
2 torchvision.datasets和torchvision.transforms结合使用
以CIFAR数据集为例:CIFAR10 数据集包含了6万张32×32像素的彩色图片,图片有10个类别,每个类别有6千张图像,其中有5万张图像为训练图片,1万张为测试图片。
参数:
- root:数据集的位置
- train: 若true则创建的是训练集,false则创建的是测试集
- transform:对数据集中的所有数据进行变换
- target_transform
- download:若true则会自动从网上下载数据集
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./datasets",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./datasets",train=False,download=True)
数据集下载过慢时:获得下载链接后,把下载链接粘贴到迅雷中,会下载压缩文件tar.gz,download依然为True,运行后会自动解压该数据
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./datasets",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./datasets",train=False,download=True)
print(test_set[0]) # 查看测试集中的第一个数据,是一个元组:(img, target)
print(test_set.classes) # 列表
img,target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
与transforms结合使用:
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 把dataset_transform运用到数据集中的每一张图片,都转为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./datasets",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./datasets",train=False,transform=dataset_transform,download=True)
# print(test_set[0])
writer = SummaryWriter("log")
#显示测试数据集中的前10张图片
for i in range(10):
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()