pyTorch框架中的torchvision
在pyTorch框架中,torchvision是一个非常重要的包,其中包含三部分:torchvision.datasets、torchvision.models、torchvision.transforms。
详细介绍参考官网和源码
pyTorch中的预训练模型构建和下载
以vgg16为例:
预训练模型中默认输入RGB图像,h和w不低于224,图像的像素值在[0,1]之间,使用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。
所以,对图像应进行如下处理:
custom_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
预训练模型的加载
当pretrained=True时,导入预训练模型
当#pretrained=False时,只导入网络结构不导入参数。
import torchvision.models as models
model = models.vgg16(pretrained=True)#获取训练好的vgg16模型
运行结果如下: