PyTorch学习之Transforms模块

以下内容全部来自 Transforms

Ⅰ. Transforms

简而言之,就是训练的数据有时候并不是机器学习训练的数据格式,这个时候就需要 Transforms 对数据进行一些操作(转换),使其适合做神经网络的输入。比如对于图像数据,通常是一个三维度 Tensor (长、宽、channels),但是神经网络通常需要一个拉直成一维的 Tensor,这个时候就需要用到 Transforms 对三维 Tensor 进行拉直。

所有的 TorchVision 数据集都有两个参数,transform 用于修改特征,target_transform 用于修改标签,tochvision.transforms 模块提供了几种开箱即用的转换。

FashionMNIST 的特征是 PIL Image 格式的图片,标签是整数,对于训练来说,特征需要拉直成一维的 Tensor,而标签需要作为-独热编码 Tensor,为了进行这些转换,使用 ToTensor 和 Lambda。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root = "data",
    train = True, 
    download = True,
    transform = ToTensor(),
    # https://blog.csdn.net/xinjieyuan/article/details/106672340
    target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Out:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gzExtracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gzExtracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Ⅱ. ToTensor()

ToTensor 将 PIL Image 或者 NumPy ndarray 转换为 FloatTensor,并将图像像素缩放到[0, 1]

Ⅲ. Lambda Transforms

Lambda 接受一个用户自定义的 lambda 函数,在这个实例中,定义了一个函数将整数转换为一个独热编码张量,首先创建一个大小为 10 的零张量(数据集中标签的数量),并且调用 scatter_,他在标签 y 给出的索引上指定 value=1

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值