torchversion的使用小技巧

预训练模型期望的输入:

  • RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。
  • 图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。

torchvision主要处理图像数据,包含一些常用的数据集、模型、转换函数等。torchvision独立于PyTorch,需要专门安装。
torchvision主要包含以下四部分:

  • torchvision.models: 提供深度学习中各种经典的网络结构、预训练好的模型,如:Alex-Net、VGG、ResNet、Inception等。
  • torchvision.datasets:提供常用的数据集,设计上继承 torch.utils.data.Dataset,主要包括:MNIST、CIFAR10/100、ImageNet、COCO等。
  • torchvision.transforms:提供常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
  • torchvision.utils:工具类,如保存张量作为图像到磁盘,给一个小批量创建一个图像网格。

1. 加载model

(1) 加载几个预训练模型
通过pretrined=True可以加载预训练模型。pretrained默认值是False,不赋值和赋值False效果一样。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
densenet = models.densenet_161()

只加载模型,不加载预训练参数 如果只需要网络结构,不需要训练模型的参数来初始化,可以将pretrained = False

保存和加载整个模型

torch.save(model_object, 'resnet.pth')
model = torch.load('resnet.pth')
# 导入模型结构
resnet18 = models.resnet18(pretrained=False)
# 加载预先下载好的预训练参数到resnet18
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
  • 12
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值