函数定义:
def diag(input: Tensor, diagonal: _int=0, *, out: Optional[Tensor]=None)
参数:
* input:tensor
* diagonal:选择输出的对角线,默认为0,即输出主对角线
实际上这个函数就是输出一个矩阵的对角线。若diagonal为正的话,输出主对角线右上角的副对角线;若diagonal为负的话,输出主对角线左上角的副对角线。
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(a)
# output:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print(torch.diag(a))
# output:
# tensor([1, 5, 9])
print(torch.diag(a, 1))
# output:
# tensor([2, 6])
print(torch.diag(a, -1))
# output:
# tensor([4, 8])
print(torch.diag(a, 2))
# output:
# tensor([3])
print(torch.diag(a, -2))
# output:
# tensor([7])