Pytorch 提供
torchvision.models
接口,里面包含了一些常用用的网络结构,并提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
官方文档地址:https://pytorch.org/docs/master/torchvision/models.html#
一、PyTorch 官方提供的网络
1、对于分类问题的网络
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3
- GoogLeNet
- ShuffleNet v2
- MobileNet v2
- ResNeXt
- Wide ResNet
- MNASNet
2、对于语义分割问题的网络
- Fully Convolutional Networks
- DeepLabV3
3、对于目标检测、图像分割、特征点检测
- Faster R-CNN
- Mask R-CNN
- Keypoint R-CNN
4、对于视频分类
- ResNet 3D
- ResNet Mixed Convolution
- ResNet (2+1)D
二、模型的导入
- 导入
resnet50
并 使用 预训练模型
import torchvision
model = torchvision.models.resnet50(pretrained=True)
- 导入
resnet50
不使用 预训练模型
import torchvision
model = torchvision.models.resnet50(pretrained=False)
运行model = torchvision.models.resnet50(pretrained=True)
的时候,是通过models
包下的resnet.py
脚本进行的,源码如下:
../vision/torchvision/models/resnet.py
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}