【深度学习笔记】pytorch框架在深度学习中常用模块之torchvision


前言

pytorch框架有很多功能模块,对于初学者来讲,可以先从基础常用的模块开始学习,这样可以快速熟悉Pytorch在实际项目上的应用。本文讲总结介绍在深度学习项目中,使用频率较多的模块之一torchvision,供大家参考学习。

torchvision属于torch的图形库,在pytorch项目中使用,它包含以下4个子模块,是我们写训练代码时经常用到的

提示:以下是本篇文章正文内容,下面案例可供参考

一、torchvision.datasets加载使用数据集

用法1:使用官方自带数据集并加载

torchvision中datasets中所有封装的数据集(MNIST、 COCO、CIFAR10、CIFAR100等)都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。示例代码如下:

torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
root (string): 表示数据集的根目录,其中根目录存在CIFAR10/processed/training.pt和CIFAR10/processed/test.pt的子目录
train (bool, optional): 如果为True,则从training.pt创建数据集,否则从test.pt创建数据集
download (bool, optional): 如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载
import torchvision
# 准备的测试数据集
from torch.utils.data import DataLoader
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(target)
print(img.shape)

用法2:使用自己的数据集(使用ImageFolder通用的数据加载器)

#数据集组织方式如下:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
torchvision.datasets.ImageFolder(root="root folder path", [transform, target_transform])
transform (callable, optional): 接收PIL图片并返回转换后版本图片的转换函数
target_transform (callable, optional): 接收PIL接收目标并对其进行变换的转换函数

二、torchvision.models 经典模型的集合,可以直接使用,无需自己在手动搭建

1.torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

代码如下(示例):

#导入resnet50网络及加载预训练模型
import torchvision
model = torchvision.models.resnet50(pretrained=True)

三、torchvision.transforms使用

torchvision.transforms用于对图像预处理操作,旋转、裁剪、数据转换等,训练前用来做数据增强

调用方法如下:
transforms.Resize(256):将图片的(h,w)均转换为(256,256)
transforms.CenterCrop(450):从中心位置将图像裁剪为(450, 450)
transforms.ToTensor():可以将图像(H x W x C)从灰度范围从0-255变换到形状为 (C x H x W)的 0-1之间,并转成tensor数据类型
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]):把图像灰度范围从把(0,1)变换到(-1,1);对每个通道而言,Normalize执行:
经常与 Compose一起使用,把多个变换 组合一起,如下示例:
代码如下(示例):

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

四、torchvision.utils使用

调用方法如下:

torchvision.utils.save_image 保存图片
torchvision.utils.make_grid 创建网格,把图片排成网格形式

torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

总结

本文主要记录介绍torchvision子模块的作用和用法,这些知识点在深度学习模型训练中非常有用,属于基础知识。博主后续会继续更新分享深度学习笔记,记录提炼知识点,总结学习经验及项目经验。如果本文对您的理解有帮助,请点赞+关注+收藏!

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

极客程序设计

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值