pytorch doc介绍 torch.eye()生成一个n×m的单位矩阵,主对角线全为1,其余位置为0。 import torch torch.eye(3) # tensor([[ 1., 0., 0.], # [ 0., 1., 0.], # [ 0., 0., 1.]])