使用工具时dim设置容易弄混,做一个小的总结
关于dim的设置感觉最简易的一种理解是:只有dim指定的维度是可变的,其他都是固定不变的。详情见:Pytorch笔记:维度dim的定义及其理解使用_Activewaste-CSDN博客
如果dim=0,即行是可变的,列数不变
torch.arange(0,6).view(2,3)
输出:tensor([[0, 1, 2], [3, 4, 5]])
torch.arange(0,6).view(2,3).sum(dim=0, keepdim=True)
输出:tensor([[3, 5, 7]]) 对应size 1,3
对应列的话
torch.arange(0,6).view(2,3).sum(dim=1, keepdim=True)
输出 : tensor([[ 3], [12]]) 对应size 2,1
如果是三维的情况:
torch.arange(0,30).view(2, 3, 5)
tensor([[[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14]], [[15, 16, 17, 18, 19], [20, 21, 22, 23, 24], [25, 26, 27, 28, 29]]])
torch.arange(0,30).view(2, 3, 5).sum(dim=0, keepdim=True)
输出:tensor([[[15, 17, 19, 21, 23], [25, 27, 29, 31, 33], [35, 37, 39, 41, 43]]]) 对应size 1, 3, 5
torch.arange(0,30).view(2, 3, 5).sum(dim=1, keepdim=True)
输出:tensor([[[15, 18, 21, 24, 27]], [[60, 63, 66, 69, 72]]]) 对应size 2, 1, 5
torch.arange(0,30).view(2, 3, 5).sum(dim=2, keepdim=True)
输出:tensor([[[ 10], [ 35], [ 60]], [[ 85], [110], [135]]]) 对应size 2, 3, 1