Usage
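In PyTorch, torch.tril returns the lower-triangular part of a matrix. Elements on and below the main diagonal are kept, and all other elements are set to 0. When the input has more than two dimensions, it is treated as a batch of matrices: the output has the same shape as the input, and the lower-triangular operation is applied to the last two dimensions.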
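Here is a small sketch of the batched behavior. The input is a stack of two 3×3 matrices of ones. Ones are an arbitrary choice that makes the kept and zeroed positions easy to see. torch.tril processes each matrix independently.

```python
import torch

# A batch of two 3x3 matrices filled with ones: shape (2, 3, 3)
x = torch.ones(2, 3, 3)

# tril works on the last two dimensions, so each 3x3 slice is processed separately
y = torch.tril(x)

print(y.shape)  # torch.Size([2, 3, 3]): the shape is unchanged
print(y[0])
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])

# Each slice equals tril applied to that slice on its own
assert all(torch.equal(y[n], torch.tril(x[n])) for n in range(x.shape[0]))
```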
torch.tril(input, diagonal=0, *, out=None) → Tensor
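- input (Tensor): the input tensor
- diagonal (int, optional): the diagonal that serves as the boundary; defaults to 0 (the main diagonal)
- out (Tensor, optional): the output tensor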
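This sketch shows how the parameters are passed. diagonal comes before the * in the signature, so it can be given positionally or by keyword. out comes after the *, so it is keyword-only. The 3×3 input with entries 1 to 9 is an arbitrary choice.

```python
import torch

# 3x3 matrix with entries 1..9, filled row by row
x = torch.arange(1, 10, dtype=torch.float32).reshape(3, 3)

# diagonal can be passed positionally or by keyword
assert torch.equal(torch.tril(x, 1), torch.tril(x, diagonal=1))

# out is keyword-only: the result is written into a preallocated tensor
result = torch.empty(3, 3)
torch.tril(x, out=result)
print(result)
# tensor([[1., 0., 0.],
#         [4., 5., 0.],
#         [7., 8., 9.]])
```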
The parameter $\mathrm{diagonal}$ selects which diagonal serves as the boundary. For a matrix $A\in\mathbb{R}^{d_1\times d_2}$, the main diagonal is the set of positions
$$\{(i,i)\mid i\in[0,\min\{d_1,d_2\}-1]\}.$$
When $\mathrm{diagonal}=k$ with $k\in\mathbb{Z}^{+}$, the boundary is the diagonal $|k|$ positions above the main diagonal:
$$\{(i,i+|k|)\mid i\in[0,\min\{d_1,d_2-|k|\}-1]\}.$$
When $\mathrm{diagonal}=k$ with $k\in\mathbb{Z}^{-}$, the boundary is the diagonal $|k|$ positions below the main diagonal:
$$\{(i+|k|,i)\mid i\in[0,\min\{d_1-|k|,d_2\}-1]\}.$$
In every case, torch.tril keeps the element at position $(i,j)$ when $j-i\le k$ and sets it to 0 otherwise.
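Put another way, torch.tril keeps the element at position $(i,j)$ exactly when $j-i\le k$. The sketch below checks this numerically. tril_by_mask is a hypothetical helper written for illustration and is not a PyTorch function. It builds that mask by hand and compares the result with torch.tril for every offset that matters for a 4×6 matrix.

```python
import torch

def tril_by_mask(x: torch.Tensor, k: int = 0) -> torch.Tensor:
    """Hypothetical helper: keep x[..., i, j] where j - i <= k, zero elsewhere."""
    rows = torch.arange(x.shape[-2]).unsqueeze(-1)  # shape (d1, 1): row index i
    cols = torch.arange(x.shape[-1]).unsqueeze(0)   # shape (1, d2): column index j
    keep = (cols - rows) <= k                       # broadcasts to (d1, d2)
    return torch.where(keep, x, torch.zeros_like(x))

b = torch.randn(4, 6)
# k = -4 zeroes every element and k = 5 keeps every element; check all offsets in between
for k in range(-4, 6):
    assert torch.equal(tril_by_mask(b, k), torch.tril(b, diagonal=k))
print("tril_by_mask matches torch.tril")
```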
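Code example
The following examples show how to call torch.tril. The inputs come from torch.randn, so the values will differ on each run: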
>>> import torch
>>> a = torch.randn(3, 3)
>>> a
tensor([[ 0.4925,  1.0023, -0.5190],
        [ 0.0464, -1.3224, -0.0238],
        [-0.1801, -0.6056,  1.0795]])
>>> torch.tril(a)
tensor([[ 0.4925,  0.0000,  0.0000],
        [ 0.0464, -1.3224,  0.0000],
        [-0.1801, -0.6056,  1.0795]])
>>> b = torch.randn(4, 6)
>>> b
tensor([[-0.7886, -0.2559, -0.9161,  0.2353,  0.4033, -0.0633],
        [-1.1292, -0.3209, -0.3307,  2.0719,  0.9238, -1.8576],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.3736, -0.7210],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950, -1.2821]])
>>> torch.tril(b)
tensor([[-0.7886,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=1)
tensor([[-0.7886, -0.2559,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209, -0.3307,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950,  0.0000]])
>>> torch.tril(b, diagonal=-1)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608,  0.0000,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=2)
tensor([[-0.7886, -0.2559, -0.9161,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209, -0.3307,  2.0719,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.3736,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950, -1.2821]])
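torch.randn gives different values on every run, so the zero pattern is easier to verify on a deterministic input. This sketch counts how many entries each row of a 4×6 matrix keeps for the diagonal values used above. The input torch.arange(1, 25) is an arbitrary choice that guarantees no kept entry is itself zero.

```python
import torch

# 4x6 matrix with entries 1..24, so every kept entry is nonzero
b = torch.arange(1, 25, dtype=torch.float32).reshape(4, 6)

for k in (-1, 0, 1, 2):
    # Count the nonzero (kept) entries in each row of the result
    kept_per_row = (torch.tril(b, diagonal=k) != 0).sum(dim=1)
    print(k, kept_per_row.tolist())
# -1 [0, 1, 2, 3]
# 0 [1, 2, 3, 4]
# 1 [2, 3, 4, 5]
# 2 [3, 4, 5, 6]
```

Row i keeps min(max(i + k + 1, 0), 6) entries. This matches the zero patterns in the random 4×6 examples above.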