目录
为什么要写这篇文章
在学习pytorch的过程中对张量中元素的操作是必不可少的,然后求和也是非常常用的,但是会发现求和操作有非常多的函数。首先就是内置sum(tensor),你还会发现也可以使用torch.sum(tensor)和tensor.sum()进行求和,那么这些函数之间有什么区别呢?接下来我将从一个最简单的例子进行说明
举例说明
sum(tensor)
# 由于针对的元素是tensor类型,所以先导入pytorch
import torch
# 创建张量
a = torch.arange(3)
b = torch.arange(9).reshape((3,3))
# 查看两个张量的形状
print(a.shape) # torch.Size([3])
print(b.shape) # torch.Size([3, 3])
print(sum(a)) # tensor(3)
print(sum(b)) # tensor([ 9, 12, 15])
从上面的输出可以看出,如果如果维度不一样的话sum(tensor)输出的是不一样的,那么他是怎么进行计算的呢?从上面两个例子还看不出规律,再来看两个例子
c = torch.arange(8).reshape((2,2,2))
d = torch.arange(16).reshape((2,2,2,2))
print(sum(c))
# tensor([[[0, 1],
# [2, 3]],
#
# [[4, 5],
# [6, 7]]])
print(sum(d))
# tensor([[[ 8, 10],
# [12, 14]],
#
# [[16, 18],
# [20, 22]]])
从这个张量c和d中就可以明显的看出一些东西,其中最直观的就是维度发生了变化,而且使用tensor.ndim输出可以看出sum(tensor)相比tensor来说维度减小了一,所以就可以看出计算规律了。
规律:假设tensor的维度是n,那么sum(tensor)的计算方式就是在n-1的维度上对应位置的元素相加从而生成一个新的张量,生成的张量维度是n-1
简单计算:首先就是把输出的第一个括号和最后一个括号去掉,这个时候就发现会分成多个独立的张量,而这些张量的维度是原始张量的维度-1,然后将这些独立的张量的对应位置的元素相加就得到了结果
torch.sum(tensor,axis)
在学会使用sum(tensor)之后我们来看看torch.sum(tensor,axis),可以看出在这个方法中我传入了一个参数叫做axis,让我们向连接一下axis的作用
axis讲解
以二维张量和三维张量为例,来看一下axis相同时不同维度输出的数值是否相等
tensor1 = torch.arange(6).reshape((2,3))
tensor2 = tensor1.reshape((1,2,3))
print(tensor1)
print(tensor1[0])
print(tensor2[0])
# 输出的数值是不一样的,也可以看一下axis=1的数值是否相同呢
print(tensor1[:,:2]) # 只看部分元素
print(tensor2[:,:2])
# 结果如下:
tensor1:tensor([[0, 1, 2],
[3, 4, 5]])
tensor1[0]:tensor([0, 1, 2])
tensor2[0]:tensor([[0, 1, 2],
[3, 4, 5]])
tensor1[:,:2]:tensor([[0, 1],
[3, 4]])
tensor2[:,:2]tensor([[[0, 1, 2],
[3, 4, 5]]])
从上诉的代码中可以看出多维数组中的axis=0并不是代表的就是行,axis并不是代表的是列,而是在多维张量中axis = 0代表的是第一层元素,axis代表的是相对于axis=1来说往里面再走一层,依次类推。
torch.sum(tensor,axis)
在了解完axis之后,torch.sum(tensor,axis)就很简单了,其实计算的方法和sum()的计算是一样的,只不过sum(tensor)其实效果适合torch.sum(tensor,axis = 0)是一样的。在计算torch.sum(tensor,axis = n)其中n的取值范围是[0,ndim-1),n只能取整数。
上面的介绍可以是比较抽象的,来看一个具体的例子,以比较高维的张量来看,这里维度取四,其他高纬度的求和方法也是一样的。
d = torch.arange(32).reshape((2,2,4,2))
print(f"d的数值{d}")
print(f"axis=0:{torch.sum(d,axis = 0)}")
print(f"axis=1:{torch.sum(d,axis = 1)}")
print(f"axis=2:{torch.sum(d,axis = 2)}")
print(f"axis=3:{torch.sum(d,axis = 3)}")
输出:
d的数值tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11],
[12, 13],
[14, 15]]],
[[[16, 17],
[18, 19],
[20, 21],
[22, 23]],
[[24, 25],
[26, 27],
[28, 29],
[30, 31]]]])
axis=0:tensor([[[16, 18],
[20, 22],
[24, 26],
[28, 30]],
[[32, 34],
[36, 38],
[40, 42],
[44, 46]]])
axis=1:tensor([[[ 8, 10],
[12, 14],
[16, 18],
[20, 22]],
[[40, 42],
[44, 46],
[48, 50],
[52, 54]]])
axis=2:tensor([[[ 12, 16],
[ 44, 48]],
[[ 76, 80],
[108, 112]]])
axis=3:tensor([[[ 1, 5, 9, 13],
[17, 21, 25, 29]],
[[33, 37, 41, 45],
[49, 53, 57, 61]]])
我从每个得到的结果中的每个元素是怎么来的进行解析:
axis = 0:首先要知道axis=0取到的是什么元素,得到的元素为:[[[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7]], [[ 8, 9], [10, 11], [12, 13], [14, 15]]]和[[[16, 17], [18, 19], [20, 21], [22, 23]], [[24, 25], [26, 27], [28, 29], [30, 31]]]两个维度是3的张量(原来的张量梯度为4),由于这两个张量是从原张量中分离出来的,所以这两个张两个的形状以及张量内的元素的类型都是相等的,所以将对应位置的元素相加就得到了结果,[[[16, 18], [20, 22], [24, 26], [28, 30]], [[32, 34], [36, 38], [40, 42], [44, 46]]]。16 = 0+16,18 = 1+17,20 = 2+18等一次类推即可
axis = 1:首先还是要知道axis=1会得到那几个元素。编号1:[[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7]],编号2: [[ 8, 9], [10, 11], [12, 13], [14, 15]],编号3:[[[16, 17], [18, 19], [20, 21], [22, 23]],编号4:[[24, 25], [26, 27], [28, 29], [30, 31]]。其中编号1和2是在同一个三维元素中分离出来的也就是axis=0时候的第一个元素,编号3和编号4的元素同理。将编号1和编号2,编号3和编号4对应位置的元素相加,然后再增加一个维度把这两个二维张量组织起来就得到了结果,结果是三维张量的。
axis = 2:同理可得:编号1:[ 0, 1],编号2:[ 2, 3],编号3:[ 4, 5],编号4:[ 6, 7],这四个编号是一组组成axis=1的编号1中的元素,同理可以得到编号5-8是一组,编号9-12是一组,13-16是一组。然后其中的第一组和第二组组成第一大组,然后第三组和第四组组成第二大组,然后将每组的元素纵向相加即可,然后的到的结果按照组的形式添加维度,知道添加维度到原张量维度减一。
axis = 3:就是axis = 2中横向相加得到元素然后添加维度即可
注意:描述的可能不是非常清楚,大家可以多做一些例子,然后结合本文章进行理解
tensor.sum()
这个也是张量中的一个求和方法,但是这个方法是比较简单的,就是返回张量中所有元素的和,也就是说得到的结果是只含有一个数值的张量
# 还是使用上面谈到的d张量
# d = torch.arange(32).reshape((2,2,4,2))
print(d.sum()) # tensor(496)
总结
tensor.sum()不管tensor的维度形状,返回的是tensor中所有元素的和作为结果张量的唯一元素
torch.tensor(tensor,axis = 0)和sum(tensor)的作用是相同的
多维张量的torch.sum(tensor,axis)尽量理解,但是常用的axis是0,1和tensor.ndim -1,其他的数值可以看情况学习。
本文详细解释了PyTorch中sum(),torch.sum(tensor)和tensor.sum()这三个张量求和函数的区别,重点讨论了axis参数如何影响结果,并提供了实例分析。
6739

被折叠的 条评论
为什么被折叠?



