title: Pytorch学习笔记-数据格式变换
学习笔记和实现代码详见如下:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
"""
ToTensor将PIL image或NumPy ndarray转换为FloatTensor。并将图像像素强度值缩放到[0.,1.)
Lambda转换应用任何用户定义的Lambda函数。
这里,我们定义了一个函数来将整数转换为一个单热编码张量。
它首先创建一个大小为10的零张量(数据集中标签的数量),并调用scatter_,它在标签y给出的索引上赋值为1。
"""