import torch as t
a = t.arange(0, 6).view(2,3)
print(a)
a.cumsum(dim=0)
a = t.arange(0, 6).view(2,3)
print(a)
a.cumsum(dim=1)
对于二维输入a,dim=0(第1行不动,将第1行累加到其他行);dim=1(进入最内层,转化成列处理。第1列不动,将第1列累加到其他列;从第一列开始后面的每一列都是前面对应行元素的累加和),运行结果如下:
参数dim
,用来指定这些操作是在哪个维度上执行的。关于dim(对应于Numpy中的axis)有提供一个简单的记忆方式:
假设输入的形状是(m, n, k)
- 如果指定dim=0,输出的形状就是(1, n, k)或者(n, k)
- 如果指定dim=1,输出的形状就是(m, 1, k)或者(m, k)
- 如果指定dim=2,输出的形状就是(m, n, 1)或者(m, n)
size中是否有"1",取决于参数keepdim
,keepdim=True
会保留维度1
。注意,以上只是经验总结,并非所有函数都符合这种形状变化方式,如cumsum
。
同理,torch.cumprod()
dim=1时,第一列不变,后面的每列将前面列的元素乘起来,如 12 = 3 * 4 ,60 = 3 * 4 * 5 。
dim=0时,第一行不变,后面每行将前面行对应元素乘起来,如 10 = 2 * 5 。