Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。
本文,我们讲述的是models,且只谈模型的加载。models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models
官方文档:https://pytorch.org/docs/master/torchvision/models.html
我将加载的方法简单总结为以下四种:
1.直接加载预训练模型
1 importtorchvision.models as models2
3 resnet50 = models.resnet50(pretrained=True)
这样就导入了resnet50的预训练模型了。
如果只需要网络结构,不需要用预训练模型的参数来初始化,那么就是:
model =torchvision.models.resnet50(pretrained