pytorch中tril函数介绍

用法介绍

pytorchtril函数主要用于返回一个矩阵主对角线以下的下三角矩阵,其它元素全部为 0 0 0。当输入是一个多维张量时,返回的是同等维度的张量并且最后两个维度的下三角矩阵的。

torch.tril(input, diagonal=0, *, out=None) ⟶ \longrightarrow Tensor

  • input(tensor):表示输入的张量
  • diagonal (int, optional):表示对角线的位置

参数 d i a g o n a l \mathrm{diagonal} diagonal主要控制矩阵主对角线元素的位置。给定一个矩阵 A ∈ R d 1 × d 2 A\in \mathbb{R}^{d_1\times d_2} ARd1×d2,则这个矩阵的主对角线元素组成的集合为 { ( i , i ) ∣ i ∈ [ 0 , min ⁡ { d 1 , d 2 } − 1 ] } \{(i,i)| i \in [0,\min\{d_1,d_2\}-1]\} {(i,i)i[0,min{d1,d2}1]}当参数 d i a g o n a l = k \mathrm{diagonal}=k diagonal=k,且 k ∈ Z + k\in\mathbb{Z}^{+} kZ+时,则此时矩阵主对角线元素的集合为 { ( i , i + ∣ k ∣ ) ∣ i ∈ [ 0 , min ⁡ { d 1 , d 2 } ] − 1 } \{(i,i+|k|)| i \in [0,\min\{d_1,d_2\}]-1\} {(i,i+k)i[0,min{d1,d2}]1}当参数 d i a g o n a l = k \mathrm{diagonal}=k diagonal=k,且 k ∈ Z − k\in\mathbb{Z}^{-} kZ时,则此时矩阵主对角线元素的集合为 { ( i + ∣ k ∣ , i ) ∣ i ∈ [ 0 , min ⁡ { d 1 , d 2 } ] − 1 } \{(i+|k|,i)| i \in [0,\min\{d_1,d_2\}]-1\} {(i+k,i)i[0,min{d1,d2}]1}

程序代码

torch.tril函数具体的程序代码示例如下所示

>>> import torch
>>> a = torch.randn(3, 4)
>>> import torch
>>> a = torch.randn(3, 3)
>>> a
tensor([[ 0.4925,  1.0023, -0.5190],
        [ 0.0464, -1.3224, -0.0238],
        [-0.1801, -0.6056,  1.0795]])
>>> torch.tril(a)
tensor([[ 0.4925,  0.0000,  0.0000],
        [ 0.0464, -1.3224,  0.0000],
        [-0.1801, -0.6056,  1.0795]])
>>> b = torch.randn(4, 6)
>>> b
tensor([[-0.7886, -0.2559, -0.9161,  0.2353,  0.4033, -0.0633],
        [-1.1292, -0.3209, -0.3307,  2.0719,  0.9238, -1.8576],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.3736, -0.7210],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950, -1.2821]])
>>> torch.tril(b)
tensor([[-0.7886,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=1)
tensor([[-0.7886, -0.2559,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209, -0.3307,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950,  0.0000]])
>>> torch.tril(b, diagonal=-1)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608,  0.0000,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=2)
tensor([[-0.7886, -0.2559, -0.9161,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209, -0.3307,  2.0719,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.3736,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950, -1.2821]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值