(1). scatter_函数详细描述如下:
scatter_(input,dim,index,value)
将value对应的值按照index确定的索引写入input张量中,其中索引是根据给定的dim(维度)来确定的。
"""
Args:
input:要进行scatter_填充的tensor
dim:在input张量进行scatter_填充的维度
index:input对应dim的填充索引,要小于对应填充维度的长度,且index维度要与input张量维度一致
value:填充值
"""
(2). 代码实现
import torch
label = torch.zeros(2, 4)
print("label:",label)
label.scatter_(dim=1,index=torch.LongTensor([[2],[3]]),value=1)
print("new_label: ",label)
显示结果:
label: tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
new_label: tensor([[0., 0., 1., 0.],
[0., 0., 0., 1.]])