计算机视觉是深度学习中最重要的一类应用,为了方便研究者应用,pytorch专门开发了一个视觉工具包torchvision。
可通过pip install torchvision安装。
torchvision主要包含以下三部分:
模型加载
- models:提供深度学习中各种经典网络结构及与训练好的模型,包括Alex-Net、VGG系列、ResNet系列、Inception系列等。
- datasets:提供常用的数据集下载,设计上都是继承torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet、COCO等。
- transform:提供常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作
from torchvision import models
from torch import nn
#加载预训练模型,如果不存在会下载
#预训练的模型保存在~/.torch/models/下面
resnet34 = models.resnet34(pretrained=True,num_classes=1000)
#修改最后的全连接层为10分类问题(默认是ImageNet上的1000分类)
resnet34.fc = nn.Linear(5