调用torchvision.models
torchvision地址
torchvision有一些可以使用的模型可以直接导入
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:
- torchvision.datasets
- torchvision.models
- torchvision.transforms
如何构建与下载
可以通过设置pretrained=True来构建, pretrained 是在ImageNet上训练的
“”"
from torchvision import models
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
如果只是需要模型pretrained默认是False
from torchvision import models
vgg16 = models.vgg16()
特别的如果预训练参数已经下载,可以直接调用参数之前下载的位置直接导入预训练的数据
import torch
from torchvision import models
pretrained_model = "地址"
vgg16 = models.vgg16(pretrained=True)
vgg16.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
- 参考
https://blog.csdn.net/u014380165/article/details/79119664
https://github.com/pytorch/vision