数据并不总是以最终处理过的形式出现,而机器学习算法需要的是经过处理的数据。我们使用转换(transforms)对数据进行一些操作,使其适用于训练。
所有的TorchVision数据集都有两个参数-transform和target_transform,用于接受包含转换逻辑的可调用对象,transform用于修改特征,target_transform用于修改标签。torchvision.transforms模块提供了几个常用的转换。
FashionMNIST的特征以PIL图像格式呈现,标签是整数。为了训练,我们需要将特征转换为归一化的张量,将标签转换为独热编码张量。为了进行这些转换,我们使用ToTensor和Lambda。
以下是每行代码的注释:
import torch # 导入PyTorch库
from torchvision import datasets # 从torchvision导入FashionMNIST数据集
from torchvision.transforms import ToTensor, Lambda # 导入ToTensor和Lambda变换
ds = datasets.FashionMNIST( # 创建FashionMNIST数据集实例
root="data", # 数据集保存在"data"目录
train=True, # 是否使用训练集
download=True, # 是否自动下载数据集
transform=ToTensor(), # 将PIL图像转换为Torch张量格式
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1) # 将标签y转换为一个1x10的张量,其中标签y对应的位置为1,其他位置为0
)
)
在最后一行代码中,target_transform参数设置为一个Lambda函数,Lambda函数接受一个参数 y,它代表标签的原始值。Lambda函数中的逻辑执行了以下操作:
- torch.zeros(10, dtype=torch.float):创建一个值为0,数据类型为torch.float,形状为1x10的张量。
- scatter_(0, torch.tensor(y), value=1):将上一步创建的零张量中,第y个位置的元素用值为1进行替换,即进行独热编码的操作。这里的scatter_是一个张量操作方法,它接受三个参数:第一个参数表示沿着哪个维度执行操作(这里是0,表示按行操作),第二个参数是一个长度为1的1D张量,表示要更新的索引位置,第三个参数表示要赋予的值(这里是1)。
通过上述操作,target_transform将原始标签 y 转换为一个形状为1x10的独热编码张量,其中原始标签对应的位置为1,其他位置为0。这样可以确保标签在训练过程中的表示与模型期望的输出格式相匹配。
ToTensor()函数
ToTensor()函数将PIL图像或NumPy数组转换为FloatTensor,并将图像的像素强度值缩放到 [0., 1.] 的范围内。
Lambda变换
Lambda变换是应用用户定义的任意lambda函数。在这里,我们定义了一个函数将整数转换为独热编码的张量。该函数首先创建一个大小为10的零张量(数据集中标签的数量),然后调用scatter_方法,根据标签y的值在相应的索引位置上赋值为1。
简而言之,ToTensor()函数将图像进行张量化并缩放,而Lambda变换通过自定义函数将整数标签转换为独热编码张量。这样可以确保标签数据的格式与模型期望的输出格式相匹配。