pytorch笔记1-scatter_用法
1.源自:pytorch教程-transforms
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data", # 设置数据集的本地路径
train=True, # 数据为训练集
download=True, # 当本地路径中没有数据集时下载数据集
transform=ToTensor(), # 设置转换函数
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) # 将传入的参数y转换成one_hot编码
)
torchvision.transforms
中的 Lambda
是对用户自定义转换函数的封装,适用于提供的转换函数不符合自己使用要求的情况。
scatter_
的两个接口 : scatter_(dim,index,value)
和 scatter_(dim,index,src)
,对其直接理解即 value
按 dim
、 index
的规则索引替换调用者自身的值。
import torch
# 第一种形式
# 一维:当dim=0时,self[index[i]] = src[i]
torch.zeros(10).scatter_(0,index=torch.tensor([1,2,3]),src=torch.tensor([2.,3.,4.]))
# 运行结果:tensor([0.,2.,3.,4.,0.,0.,0.,0.,0.,0.])
'''二维:
self[index[i]][j] = src[i][j], # if dim == 0
self[i][index[j]] = src[i][j], # if dim == 1'''
x = torch.rand(2, 5)
# tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src=x)
# 运行结果:tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
'''三维:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2'''
# 举例同理按公式索引替换
# 第二种形式
# 举例
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
target_transform(5)
# 运行结果:tensor([0.,0.,0.,0.,1.,0.,0.,0.,0.,0.]),应用场景:将label转换为one-hot编码