Pytorch——计算机视觉工具包:torchvision
torchvision
独立于Pytorch,需通过pip install torchvision
安装。
torchvision 主要包含以下三部分:
- models : 提供深度学习中各种经典的网络结构以及训练好的模型,包括Alex Net, VGG系列、ResNet系列、Inception系列等;
- datasets:提供常用的数据集加载,设计上都是继承torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet、COCO等;
- transforms: 提供常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
1.models
from torchvision import models
from torch import nn
#加载预训练好的模型,如果不存在会下载
resnet34=models.resnet34(pretrained=True, num_classes=1000)
resnet34.fc=nn.Linear(512,10)
2.datasets
from torchvision import datasets
from torchvision import transforms as T
transform=T.Compose([
T.Resize(224), #缩放图片(Image),保持长宽比不变,最短边为224像素
T.CenterCrop(224), #从图片中间裁剪出224*224的图片
T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】
T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至【-1,1】,规定均值和方差
])
dataset=datasets.MNIST('data/',download=True, train=False, transform=transform)
len(dataset)
3.transforms
import torch as t
from torchvision import transforms
to_pil=transforms.ToPILImage()
to_pil(t.randn(3,64,64))
输出的是随机噪声:
4.make_grid、save_img
torchvision 还提供了两个常用的函数,一个是make_grid
,它能将多张图片拼接在一个网格中;