Pytorch官网教程:https://pytorch.org/tutorials/
主要内容:
1.总体介绍
2.ToTensor()
3.Lambda Transforms 详解scatter_函数
1.总体介绍
在学习日记2中,预加载的数据集或者自定义的数据集的形式一般不是机器学习算法训练所要求的最终形式,所以需要进行数据格式转换。torchvision.transforms提供了很多常用转换。
在torchvision的datasets中有两个参数:
transfrom:修改特征
target_transform:修改标签
举例:学习日记2中预加载的FashionMNIST特征是PIL图像格式。对于训练,需要将特征作为归一化张量,并将标签作为 one-hot 编码张量。为了进行这些转换,需要使用 ToTensor
和 Lambda
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor,Lambda
ds=datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y:torch.zeros(10,dtype=torch.float).scatter_(0,torch.tensor(y),value=1))
# torch.zeros(10,dtype=torch.float) 返回一个由0填充的张量,形状由size决定。
)
我们着重来看一下transform和target_transform,其余在学习日记2中已经注明。
2.ToTensor()
ToTensor()将PIL图像或NumPy ndarray转换为浮点张量。并在[0,1]范围内缩放图像的像素强度值。
3.Lambda Transforms 详解scatter_函数
(1)对于lambda函数,请自行了解。
(2)scatter_函数
target.scatter(dim, index, src)
target
:即目标张量,将在该张量上进行映射
src
:即源张量,将把该张量上的元素逐个映射到目标张量上
dim
:指定轴方向,定义了填充方式。dim=0行填充,dim=1列填充
index
: 按照轴方向,在target
张量中需要填充的位置
很抽象,看这篇文章就懂了,写的很好:【Pytorch】scatter函数详解
target_transform=Lambda(lambda y:torch.zeros(10,dtype=torch.float).scatter_(0,torch.tensor(y),value=1))
这一行代码的意思就是用1填充由0填充的,大小为10的张量,行填充,填充位置由y确定。
注:pytorch中,一般函数加下划线代表直接在原来的Tensor上修改。
scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式。
(3)one-hot编码,请自行了解。