一.dim的定义
dim的不同值代表不同的维度,例如在二维张量中dim=0代表的是行,dim=1代表的是列。广泛的说,在多维张量()中,dim=0就是指,dim=n是指
二.例子
torch.sum()
input:输入的张量
dim:需要消减的维度
keepdim:输出张量中是否保存指定dim维的张量
eg1:
b = torch.arange(3 * 2 * 2).view(3, 2, 2)
print(b)
print(torch.sum(b, (1, 2)))
输出结果为:
这里的输出结果是按照第0维进行相加的,原因是因为dim=(1, 2)将这两维进行消减,从而根据剩下的一维进行求和计算。
eg2:
b = torch.arange(3 * 2 * 2).view(3, 2, 2)
print(b)
print(torch.sum(b, 1))
这里将第一维消减后还有两维,所以最终的输出结果将按照第0维以及第2维进行计算,所以最终张量的维度为3*2.