Pytorch学习日记3:Transforms

Pytorch官网教程:https://pytorch.org/tutorials/   

主要内容:

1.总体介绍

2.ToTensor()

3.Lambda Transforms  详解scatter_函数

1.总体介绍

在学习日记2中,预加载的数据集或者自定义的数据集的形式一般不是机器学习算法训练所要求的最终形式,所以需要进行数据格式转换。torchvision.transforms提供了很多常用转换。

在torchvision的datasets中有两个参数:

transfrom:修改特征

target_transform:修改标签

举例:学习日记2中预加载的FashionMNIST特征是PIL图像格式。对于训练,需要将特征作为归一化张量,并将标签作为 one-hot 编码张量。为了进行这些转换,需要使用 ToTensor和 Lambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor,Lambda

ds=datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y:torch.zeros(10,dtype=torch.float).scatter_(0,torch.tensor(y),value=1))
   # torch.zeros(10,dtype=torch.float)  返回一个由0填充的张量,形状由size决定。
)

我们着重来看一下transform和target_transform,其余在学习日记2中已经注明。

2.ToTensor()

ToTensor()将PIL图像或NumPy ndarray转换为浮点张量。并在[0,1]范围内缩放图像的像素强度值。

3.Lambda Transforms  详解scatter_函数

(1)对于lambda函数,请自行了解。

(2)scatter_函数

target.scatter(dim, index, src)

target:即目标张量,将在该张量上进行映射

src:即源张量,将把该张量上的元素逐个映射到目标张量上

dim:指定轴方向,定义了填充方式。dim=0行填充,dim=1列填充

index: 按照轴方向,在target张量中需要填充的位置

很抽象,看这篇文章就懂了,写的很好:【Pytorch】scatter函数详解

target_transform=Lambda(lambda y:torch.zeros(10,dtype=torch.float).scatter_(0,torch.tensor(y),value=1))

这一行代码的意思就是用1填充由0填充的,大小为10的张量,行填充,填充位置由y确定。

注:pytorch中,一般函数加下划线代表直接在原来的Tensor上修改

scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式。

(3)one-hot编码,请自行了解。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

南风知我意95

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值