torch.triu()
作用:用于获取矩阵的上三角部分。
其定义为:
torch.triu(input, diagonal=0)
- input: 输入矩阵
- diagonal: 对角线之上为真值。0代表主对角线,正数表示对角线之上为真值,负数表示对角线之下为真值。
例如,如果你有一个矩阵:
取主对角线之上的元素。
tensor = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
torch.triu(tensor)
# tensor([[1, 2, 3],
# [0, 5, 6],
# [0, 0, 9]])
取对角线之上的元素。
torch.triu(tensor, diagonal=1)
# tensor([[0, 2, 3],
# [0, 0, 6],
# [0, 0, 0]])
取对角线之下的元素。
torch.triu(tensor, diagonal=-1)
# tensor([[1, 0, 3],
# [4, 5, 0],
# [7, 8, 9 ]])
这个函数在许多情况下很有用:
- Masking - 将矩阵的一部分元素屏蔽/遮挡,只保留上三角或下三角部分。
- 对称矩阵矢量化 - 将上三角矩阵元素展平到向量中。
- 矩阵操作 - 仅对矩阵的上三角或下三角部分执行某些操作。
例子:
# Masking
tensor = torch.ones(3, 3)
torch.triu(tensor, diagonal=1) # 遮挡除对角线外的其他元素
# 矢量化对称矩阵
tensor = torch.tensor([[1, 2, 3],
[2, 4, 5],
[3, 5, 6]])
triu_idx = torch.triu_indices(3, 2) # 获取上三角索引
triu_vec = tensor[triu_idx]
# tensor([1, 2, 4, 5, 6])
# 矩阵操作
tensor = torch.ones(3, 3)
tensor[torch.triu_indices(3, 1)] += 2 # 仅对上三角部分元素加2
# tensor([[1., 2., 3.],
# [2., 2., 3.],
# [3., 3., 2.]])
总之,torch.triu()是一个用于获取矩阵上三角或下三角部分的简单但非常有用的函数,可以用于遮挡、矢量化和矩阵操作等目的。