Pytorch:快速求得NxN矩阵的主对角线(diagonal)元素与非对角线元素

前因

今天写代码过程中,要分别求得一个N x NTensor矩阵的主对角元素与非对角元素,看到一个巧妙的写法,故记录下来。

探索

torch.diagonal()

主对角元素很好得到,Pytorch有现成的API可调用,为torch.diagonal,详情如下:
在这里插入图片描述
使用起来也很方便,示例如下:

x = torch.randn(4,4)
# tensor([[ 0.9148,  0.1396, -0.8974,  2.0014],
#        [ 0.1129, -0.3656,  0.4371,  0.2618],
#        [ 1.1049, -0.0774, -0.4160, -0.4922],
#        [ 1.3197, -0.2022, -0.0031, -1.3811]])

torch.diagonal(x)
# tensor([ 0.9148, -0.3656, -0.4160, -1.3811])

妙用矩阵变换

关于非对角线元素,就没有特定的API了,找寻资料的过程中,看到一个比较巧妙的方法,直接上代码:

主要参考依据:https://github1s.com/facebookresearch/barlowtwins/blob/e6f34a01c0cde6f05da6f431ef8a577b42e94e71/main.py#L207

n, m = x.shape
assert n == m
x.flatten()[:-1].view(n-1,n+1)[:,1:].flatten()
# tensor([ 0.1396, -0.8974,  2.0014,  0.1129,  0.4371,  0.2618,  1.1049, -0.0774,
#        -0.4922,  1.3197, -0.2022, -0.0031])

核心代码就在最后一行,下面主要分解一下来介绍。首先利用flatten()拉直向量,然后去掉最后一个元素,得到 n 2 − 1 n^2-1 n21个元素,然后构造为一个维度为[N-1, N+1]的矩阵。在这个矩阵中,之前所有的对角线元素全部出现在第1列,如下所示:
在这里插入图片描述
然后根据索引获取[:, 1:]元素,得到的就是原矩阵的非对角线元素了。

  • 23
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值