前因
今天写代码过程中,要分别求得一个N x N
Tensor矩阵的主对角元素与非对角元素,看到一个巧妙的写法,故记录下来。
探索
torch.diagonal()
主对角元素很好得到,Pytorch有现成的API可调用,为torch.diagonal
,详情如下:
使用起来也很方便,示例如下:
x = torch.randn(4,4)
# tensor([[ 0.9148, 0.1396, -0.8974, 2.0014],
# [ 0.1129, -0.3656, 0.4371, 0.2618],
# [ 1.1049, -0.0774, -0.4160, -0.4922],
# [ 1.3197, -0.2022, -0.0031, -1.3811]])
torch.diagonal(x)
# tensor([ 0.9148, -0.3656, -0.4160, -1.3811])
妙用矩阵变换
关于非对角线元素,就没有特定的API了,找寻资料的过程中,看到一个比较巧妙的方法,直接上代码:
n, m = x.shape
assert n == m
x.flatten()[:-1].view(n-1,n+1)[:,1:].flatten()
# tensor([ 0.1396, -0.8974, 2.0014, 0.1129, 0.4371, 0.2618, 1.1049, -0.0774,
# -0.4922, 1.3197, -0.2022, -0.0031])
核心代码就在最后一行,下面主要分解一下来介绍。首先利用flatten()
拉直向量,然后去掉最后一个元素,得到
n
2
−
1
n^2-1
n2−1个元素,然后构造为一个维度为[N-1, N+1]
的矩阵。在这个矩阵中,之前所有的对角线元素全部出现在第1列,如下所示:
然后根据索引获取[:, 1:]
元素,得到的就是原矩阵的非对角线元素了。