pytorch p29 torchvision的介绍

一 概述

本节课开始学习torchvision,老师 说不会吧每个点展开讲一遍,穿插在 例子中介绍。下面是官网介绍 :

torchvision — Torchvision 0.11.0 documentation

This library is part of the PyTorch project. PyTorch is an open source machine learning framework.

The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

    大意就是torchvision是  深度学习框架pytorch的组成部分,包含了计算机视觉领域的数据集、模型、通用图片转换等,下面展开看下各个模块的内容。

二 datasets

     torchvision.datasets是继承torch.utils.data.Dataset, they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples in parallel using torch.multiprocessing workers.

我用上一节的例子

# 加载训练集
train_dataset = datasets.MNIST(
                                root='./data',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True
                                )
# 记载测试集
test_dataset = datasets.MNIST(root='./data',
                             train=False,
                             transform=transforms.ToTensor())
# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                         batch_size=batch_size,
                                         )

They all have two common arguments: transform and target_transform to transform the input and target respectively. 

  • transform:可以对数据进行的变换;
  • target_transform:可以对标签进行的变换。

上面的也用了 transforms.ToTensor()

torchvision.datasets包括以下内容: 

torchvision.models

torchvision.models包含下列模型的定义:

      The models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection and video classification.

     有模型 还有权重、偏置参数:

You can construct a model with random weights by calling its constructor:

构建一个模型,随机初始化参数

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()

We provide pre-trained models, using the PyTorch torch.utils.model_zoo. These can be constructed by passing pretrained=True:

使用预训练的模型进行参数初始化. 

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
  • torchvision.transforms

Most transformations accept both PIL images and tensor images, although some transformations are PIL-only and some are tensor-only. The Conversion Transforms may be used to convert to and from PIL images.

torchvision.transforms.functional 模块提供了一些常用的转换,这些转换都能够接受以下三种输入:

PIL Image:对于 RGB 图像,size 为 (W, H),将其转换为 NumPy array 后 size 为 (H, W, C);
Tensor Image:指具有 shape 为 (C, H, W) 的一个 tensor,C 为通道数,H、W 分别是图像的高和宽;
batch of Tensor Images:指具有 shape 为 (B, C, H, W) 的一个 tensor,B 为 batchsize,也就是一个批次中的图像数量。

里面功能很多,有对图片的旋转、裁剪、图片转换类等,后面的例子再整理。

安装可以使用 :pip3 install torchvision

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值