pytorch学习记录(一)

本文深入探讨了PyTorch中的scatter_函数,包括其两种接口的使用方法。通过示例展示了如何在一维和二维张量中进行索引替换操作,并解释了其在one-hot编码转换中的应用。此外,还提供了Lambda包装自定义转换的实例,用于将标签转换为one-hot编码。
摘要由CSDN通过智能技术生成

pytorch笔记1-scatter_用法

1.源自:pytorch教程-transforms

import torch	
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",				  # 设置数据集的本地路径
    train=True,					  # 数据为训练集
    download=True,			 	# 当本地路径中没有数据集时下载数据集
    transform=ToTensor(),		   # 设置转换函数
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))		 				# 将传入的参数y转换成one_hot编码
)

torchvision.transforms 中的 Lambda 是对用户自定义转换函数的封装,适用于提供的转换函数不符合自己使用要求的情况。

scatter_ 的两个接口 : scatter_(dim,index,value)scatter_(dim,index,src) ,对其直接理解即 valuedimindex 的规则索引替换调用者自身的值。

import torch
# 第一种形式
# 一维:当dim=0时,self[index[i]] = src[i]
torch.zeros(10).scatter_(0,index=torch.tensor([1,2,3]),src=torch.tensor([2.,3.,4.]))
# 运行结果:tensor([0.,2.,3.,4.,0.,0.,0.,0.,0.,0.])

'''二维:
self[index[i]][j] = src[i][j], # if dim == 0
self[i][index[j]] = src[i][j], # if dim == 1'''
x = torch.rand(2, 5)
# tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src=x)
# 运行结果:tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#        	[0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#        	[0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

'''三维:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2'''
# 举例同理按公式索引替换

# 第二种形式
# 举例
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
target_transform(5)
# 运行结果:tensor([0.,0.,0.,0.,1.,0.,0.,0.,0.,0.]),应用场景:将label转换为one-hot编码

参考链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值