1. 背景
可以先看torch官方文档介绍:
主要作用是根据索引值index,向tensor中指定dim维度的index位置写入scr所对应的数值,可以用来生成one-hot向量和特定mask,熟练使用该函数,就不用暴力for循环啦。
2. 函数应用
(1)one-hot向量生成
将图像的标签值转换为一组只有0和1组成的向量,这是DL领域常用到的。
import torch
targets=torch.zeros(3,5)
index = torch.LongTensor([[3],[2],[5]])
targets.scatter_(1,index-1,1)
print(targets.size(),index.size())
print(targets)
要注意维度:
(2)多标签或mask的one-hot向量生成
import torch
targets=torch.zeros(3,5)
index1 = torch.LongTensor([[3],[2],[5]])
index2 = torch.LongTensor([[1],[2],[4]])
targets.scatter_(1,index1-1,1)
targets.scatter_(1,index2-1,1)
print(targets)
结果:
(3)根据位置插入指定值
import torch
targets=torch.zeros(3,5)
scr1=torch.Tensor([[0.1],[0.2],[0.3]])
scr2=torch.Tensor([[0.6],[0.5],[0.4]])
index1 = torch.LongTensor([[3],[2],[5]])
index2 = torch.LongTensor([[1],[2],[4]])
targets.scatter_(1,index1-1,scr1)
targets.scatter_(1,index2-1,scr2)
print(targets)
结果: