前言
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。这3个子包的具体介绍可以参考官网:http://pytorch.org/docs/master/torchvision/index.html。具体代码可以参考github:https://github.com/pytorch/vision/tree/master/torchvision。
这篇博客介绍torchvision.models。torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
1. 都有哪些模型?
PyTorch定义了几个常用模型,并且提供了预训练版本:
- AlexNet: AlexNet variant from the “One weird trick” paper.
- VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
2. 如何构建和下载?
预训练模型可以通过设置pretrained=True来构建:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。
如果只需要网络结构,不需要用与训练模型的参数来初始化,可以将pretrained = False
model = torchvision.models.densenet169(pretrained=False)
# 等价于:
model = torchvision.models.densenet169()
举例子:
import torchvision.models as models
vgg16 = models.vgg16(pretrained = True) # 获取训练好的VGG16模型
pretrained_dict = vgg16.state_dict() # 返回包含模块所有状态的字典,包括参数和缓存
运行上面的代码,开始下载vgg16模型。