文章目录
torch.Tensor.scatter_(dim, index, src)
scatter_(dim, index, src) → Tensor
官方解释:
Writes all values from the tensor src
into self
at the indices specified in the index
tensor. For each value in src
, its output index is specified by its index in src
for dimension != dim
and by the corresponding value in index
for dimension = dim
.
不用看了,看了也不懂。。。
scatter_()操作可以这么理解:
假设
x
是一个tensor
,x.scatter_(dim, index, src)
表示对x
中的元素按某种规则进行修改(把某些来自参数src
的元素放置进来替换掉原来位置的元素,没被替换的则保持原样), 而修改的依据则由参数dim
和参数index
决定。
scatter_()的三个输入参数:
- src – 需要放置的元素就在这里面选取,可以是张量或标量
- dim – 沿着哪个维度进行选取(index操作)
- index – 需要选取的元素对应的索引
还是略抽象,看个简单的例子吧:
x = torch.zeros(3,3)
>>> x
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
src = torch.ones(2,3)
>>> src
tensor([[1., 1., 1.],
[1., 1., 1.]])
index = torch.tensor([[0,0,1],[1,0,0]])
>>> index
tensor([[0, 0, 1],
[1, 0, 0]])
dim
先取0,操作结果:
x = x.scatter_(0, index, src) # dim=0
>>> x
tensor([[1., 1., 1.],
[1., 0., 1.],
[0., 0., 0.]])
以上操作可以简化为:
torch.zeros(3,3).scatter_(0, torch.tensor([[0,0,1],[1,0,0]]), torch.ones(2,3))
可以看到,scatter_()操作将src
中的某些元素 1 放置到了x
中原本是 0 的位置,而放置的依据是dim
和index
两个参数,具体分析如下:
x
是二维的,可以用x[i][j]
来表示x
中的相应位置的元素,比如:
x = torch.zeros(3,3)
>>> x[0][0]
tensor(0.)
-
而
dim = 0
代表着沿着第0个维度进行操作,也就是对x[i][j]
中的i
进行相应的操作,替换规则如下:
x [ i n d e x [ i ] [ j ] ] [ j ] = s r c [ i ] [ j ] (1) x[index[i][j]][j] = src[i][j] \tag{1} x[index[i][j]][j]=src[i][j](1)其中
i
为0、1,j
为 0、1、2(这儿i
,j
的取值不要超过索引范围就ok啦)。当
i=0
,j=0,1,2
:
x[index[0][0]][0]
,index[0][0] = 0
⇒x[0][0]
=src[0][0]
= 1
x[index[0][1]][1]
,index[0][1] = 0
⇒x[0][1]
=src[0][1]
= 1
x[index[0][2]][2]
,index[0][2] = 1
⇒x[1][2]
=src[0][2]
= 1当
i=1
,j=0,1,2
:
x[index[1][0]][0]
,index[1][0] = 1
⇒x[1][0]
=src[1][0]
= 1
x[index[1][1]][1]
,index[1][1] = 0
⇒x[0][1]
=src[1][1]
= 1
x[index[1][2]][2]
,index[1][2] = 0
⇒x[0][2]
=src[1][2]
= 1所以上述操作就是将
x
中的某些元素修改为src
中对应位置的元素,而对应关系是由dim
,index
来决定 的。注意上述操作中,i=0
的第二项和i=1
的第二项实际上是都是对x[0][1]
进行修改,只不过修改之后的值恰 好是相同的。 -
那么
dim = 1
就代表着沿着第1个维度进行操作,也就是对x[i][j]
中的j
进行相应的操作,规则如下:
x [ i ] [ i n d e x [ i ] [ j ] ] = s r c [ i ] [ j ] (2) x[i][index[i][j]] = src[i][j] \tag{2} x[i][index[i][j]]=src[i][j](2)
具体就不展开细说了。
下面还是来看一下官方例子吧:
src = torch.rand(2, 5)
>>> src
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
x = torch.zeros(3, 5)
x.scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
随便测试一下:
x[index[0][0]][0]
, index[0][0] = 0
⇒ x[0][0]
= src[0][0]
= 0.3992
x[index[0][2]][2]
, index[0][2] = 2
⇒ x[2][2]
= src[0][2]
= 0.9044
是不是都对应上啦 (✿◡‿◡)
除了二维的,三维的也是同样的道理,官网也给出的三维的对照规则:
if dim
= 0,
s
e
l
f
[
i
n
d
e
x
[
i
]
[
j
]
[
k
]
]
[
j
]
[
k
]
=
s
r
c
[
i
]
[
j
]
[
k
]
self[index[i][j][k]][j][k] = src[i][j][k]
self[index[i][j][k]][j][k]=src[i][j][k]
if dim
= 1,
s
e
l
f
[
i
]
[
i
n
d
e
x
[
i
]
[
j
]
[
k
]
]
[
k
]
=
s
r
c
[
i
]
[
j
]
[
k
]
self[i][index[i][j][k]][k] = src[i][j][k]
self[i][index[i][j][k]][k]=src[i][j][k]
if dim
= 2,
s
e
l
f
[
i
]
[
j
]
[
i
n
d
e
x
[
i
]
[
j
]
[
k
]
]
=
s
r
c
[
i
]
[
j
]
[
k
]
self[i][j][index[i][j][k]] = src[i][j][k]
self[i][j][index[i][j][k]]=src[i][j][k]
好了,前面我们说到了src
除了是张量之外,还可以是标量,比如之前的例子:
torch.zeros(3,3).scatter_(0, torch.tensor([[0,0,1],[1,0,0]]), torch.ones(2,3))
修改为:
torch.zeros(3,3).scatter_(0, torch.tensor([[0,0,1],[1,0,0]]), 1)
也是同样的效果。
scatter_()操作常常用来对标签进行 one-hot 编码,这也是我最开始碰到这个函数的地方。one-hot 编码使用的src
一般就为标量,也就是利用一个标量对张量进行修改。
我们在计算 loss(比如Cross entropy loss)的时候,预测矩阵的形状一般为 (batch_size, num_classes),而标签的形状一般为 (batch_size, 1),需要将标签形状修改为(batch_size, num_classes)以方便计算,其实就是将标量扩展成张量。一般这么进行:
num_classes = 5
batch_size = 4
label = torch.randint(0, num_classes, size=(batch_size,1))
>>> label
tensor([[0],
[1],
[4],
[3]])
原始4个样本的标签分别是0、1、4、3,代表样本分别属于第一、二、五、四个类别,此处回顾一下交叉熵的计算方式:
l o s s = − ∑ i y i ln y ^ i (3) loss=-\sum_{i} y_{i} \ln \hat{y}_{i} \tag{3} loss=−i∑yilny^i(3)
次数的
y
i
y_{i}
yi 和
y
^
i
\hat{y}_{i}
y^i 具有相同的形状。比如上述的label
中第一个样本的标签是0
,把它转为长为num_classes
张量即为:
# 第一个类别处为1, 其余全为0
tensor([1, 0, 0, 0, 0])
one-hot 编码其实就是这样的操作,所属标签类别处为1,其余为0,用scatter_()函数实现如下:
# dim = 1
one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
>>> one_hot
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 1., 0.]])
注意到编码后的one-hot
张量对应每个样本的标签处都实现了相应的转换。
其实在PyTorch 1.1版本之后,提供了更简单的 one-hot 编码方式:
https://stackoverflow.com/questions/44461772/creating-one-hot-vector-from-indices-given-as-a-tensor
one_hot = torch.nn.functional.one_hot(label, 5)
>>> one_hot
tensor([[[1, 0, 0, 0, 0]],
[[0, 1, 0, 0, 0]],
[[0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 0]]])
不过要注意生成的张量形状略有不同(此处为4x1x5,之前为4x5)。如果需要变换为4x5:
one_hot.squeeze(1)
即可!!!
torch.gather(input, dim, index, out=None, sparse_grad=False)
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
理解了上面的scatter_()
,再来理解gather()
就很简单了,因为他们两个互为逆向操作,对于torch.gather(),三维操作规则如下:
out[i][j][k] = input[index[i][j][k]][j][k]
# if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]
# if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]
# if dim == 2
是不是看起来就很熟悉了,只是satter_()操作的三维规则逆向过来。同理二维的操作规则:
out[i][j] = input[index[i][j]][j]
# if dim == 0
out[i][j] = input[i][index[i][j]]
# if dim == 1
例子:
input = torch.tensor([[1,2],[3,4]])
output = torch.gather(input, 0, torch.tensor([[0,0],[1,0]]))
>>> output
tensor([[1, 2],
[3, 2]])
取两个测试一下:
output[0][0]
= input[index[0][0]][0]
= input[0][0]
= 1
output[1][1]
= input[index[1][1]][1]
= input[0][1]
= 2
完。