torch.diagonal()

torch.diagonal()函数规则

定义:返回一个局部视图(类似于数据库里面的视图,但是属实没太弄明白返回的规则)

引用一下官方文档(简单的能看懂,复杂的属实看不懂,而且相关的解释也没找到,只有自己总结,总结的不是很全面,希望大佬能够指点迷津)

参数有四个
  • input (Tensor) – the input tensor. Must be at least 2-dimensional.
  • offset (int, optional) – which diagonal to consider. Default: 0 (main diagonal).
  • dim1 (int, optional) – first dimension with respect to which to take diagonal. Default: 0.
  • dim2 (int, optional) – second dimension with respect to which to take diagonal. Default: 1.
  1. 第一个参数就是传入的tensor类型的矩阵
  2. 第二个参数应该是偏移量(暂时觉得offset取0之外的值太复杂)
  3. 没整明白
  4. 没整明白
简单的实例,二维矩阵

在这里插入图片描述
直接引用官方的实例数据,此时传入的0,1数字应该是offset这一偏移量,来取对应的对角线元素。

复杂的实例,三维矩阵

实例1

a=torch.rand(2,3,4)
print(a)
b=torch.diagonal(a,dim1=0,dim2=1)
print(b)
print(b.shape)

得到的输出

tensor([[[0.5884, 0.0792, 0.1939, 0.3175],
         [0.5259, 0.4497, 0.2920, 0.6755],
         [0.9257, 0.2762, 0.1504, 0.7462]],

        [[0.6629, 0.0911, 0.8767, 0.4005],
         [0.8840, 0.2576, 0.2401, 0.1613],
         [0.3671, 0.0966, 0.2090, 0.6447]]])
tensor([[0.5884, 0.8840],
        [0.0792, 0.2576],
        [0.1939, 0.2401],
        [0.3175, 0.1613]])
torch.Size([4, 2])

输出数据对应原矩阵的位置:(dim1,dim2对应的位置不发生变化。**行变列不变**)
[(0,0,0)  (1,1,0)]
[(0,0,1)  (1,1,1)]
[(0,0,2)  (1,1,2)]
[(0,0,3)  (1,1,3)]

实例2

a=torch.rand(2,3,4)
print(a)
b=torch.diagonal(a,dim1=0,dim2=2)
print(b)
print(b.shape)

得到的输出

tensor([[[0.0454, 0.7792, 0.9889, 0.9855],
         [0.7824, 0.9162, 0.8135, 0.2376],
         [0.0701, 0.5878, 0.2338, 0.3420]],

        [[0.6993, 0.3415, 0.4539, 0.7812],
         [0.0872, 0.2570, 0.5155, 0.7736],
         [0.3762, 0.7464, 0.1557, 0.1252]]])
tensor([[0.0454, 0.3415],
        [0.7824, 0.2570],
        [0.0701, 0.7464]])
torch.Size([3, 2])
输出数据对应原矩阵的位置:
[(0,0,0)  (1,0,1)]
[(0,1,0)  (1,1,1)]
[(0,2,0)  (1,2,1)]


实例3

a=torch.rand(2,3,4)
print(a)
b=torch.diagonal(a,dim1=1,dim2=2)
print(b)
print(b.shape)

得到的输出

tensor([[[0.1378, 0.0287, 0.0254, 0.0355],
         [0.5497, 0.2742, 0.0664, 0.7303],
         [0.2270, 0.8366, 0.2908, 0.3661]],

        [[0.3372, 0.1650, 0.9361, 0.5833],
         [0.9213, 0.3715, 0.0806, 0.5747],
         [0.0688, 0.6735, 0.5550, 0.4947]]])
tensor([[0.1378, 0.2742, 0.2908],
        [0.3372, 0.3715, 0.5550]])
torch.Size([2, 3])

通过上面的观察,输出的数据维度与传入的dim1,dim2有很大的关系,最终得到的数据维度是:(除dim1/dim2之外的维度作为第一维度大小,剩下的作为第二维度大小<即dim1和dim2中较小的数>)<同样适用于更高维的矩阵中>

复杂的实例,四维矩阵

实例4

a=torch.rand(2,3,4,2)
print(a)
b=torch.diagonal(a,dim1=1,dim2=2)
print(b)
print(b.shape)

得到的输出:

tensor([[[[0.7196, 0.9937],
          [0.1257, 0.8227],
          [0.0641, 0.9343],
          [0.8150, 0.6029]],

         [[0.4693, 0.8988],
          [0.3097, 0.8774],
          [0.7828, 0.5973],
          [0.3847, 0.8274]],

         [[0.3126, 0.2040],
          [0.7447, 0.5588],
          [0.9778, 0.5571],
          [0.9159, 0.4530]]],


        [[[0.8360, 0.5035],
          [0.6402, 0.1219],
          [0.8775, 0.9003],
          [0.8240, 0.7149]],

         [[0.9807, 0.2547],
          [0.0715, 0.5177],
          [0.9933, 0.1935],
          [0.5069, 0.2203]],

         [[0.5923, 0.4335],
          [0.3306, 0.7048],
          [0.2834, 0.2013],
          [0.7158, 0.1417]]]])
tensor([[[0.7196, 0.3097, 0.9778],
         [0.9937, 0.8774, 0.5571]],

        [[0.8360, 0.0715, 0.2834],
         [0.5035, 0.5177, 0.2013]]])
torch.Size([2, 2, 3])

输出数据对应原矩阵的位置:
[(0,0,0,0)  (0,1,1,0)  (0,2,2,0)]
[(0,0,0,1)  (0,1,1,1)  (0,2,2,1)]
----------------------------------
[(1,0,0,0)  (1,1,1,0)  (1,2,2,0)]
[(1,0,0,1)  (1,1,1,1)  (1,2,2,1)]

注:

每一行数据中,dim1、dim2下标同时变动,变动的范围主要由dim1和dim2中较小的一个数决定(若dim1<dim2),则dim1、dim2变动的范围是从(0,dim1-1)到(0,dim1-1)
每一列数据中,dim1、dim2处的数据下标固定,变动的是剩下的维度大小,变动的范围也为该维度对应的范围

  • 11
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值