函数 | 作用 |
---|---|
norm | 范数 |
mean | 均值 |
sum | 总和 |
prod | 累乘 |
max min | 最大值 最小值(也能有索引) |
argmax argmin | 最大值最小值所在索引 |
topk | 前K大或前K小的数及索引 |
kthvalue | 第K小的数及索引 |
1、norm
这个norm是范数的意思,并不是normalize正则化。
vector norm 和 matrix norm计算不同:
对于向量来说,是向量中各元素的p次方和的p次根;而对于矩阵来说,是矩阵中对应位置元素的p次方和的p次根。
a = torch.full([8], 1).float()
b = a.view(2, 4)
c = a.view(2, 2, 2)
b
c
a.norm(1), b.norm(1), c.norm(1)
#a.norm(p, dim=?) 1范数就是所有数绝对值求和,都是8个1
a.norm(2), b.norm(2), c.norm(2)
#2范数是所有数平方和之后开平方根,都是根号8
b.norm(1, dim=1)
#b的shape:[2,4],dim=1就把第二个维度消掉,一行四列向量norm分别计算,结果shape为[2]
b.norm(2, dim=1)
c.norm(1, dim=0)
#c的shape为[2,2,2],dim=0就把第一个维度消掉,两个两行两列矩阵norm计算,结果shape为[2,2]
c.norm(2, dim=0)
# [2,2]shape的矩阵,第n行第n列的值为两矩阵中该位置元素平方和的平方根
2、mean sum prod min max
prod是累乘,其它几个函数都日常认识
a = torch.arange(8).view(2,4).float()
a.mean(), a.sum(), a.prod()
a.min(), a.max()
对于 argmin argmax:
如果不给定dim=?,就在使用argmin/max时会默认把张量打平,返回打平会所有元素中最小或最大值的索引,结果为单一标量。
如果给定dim=?,会消除该指定维度,并返回该维度的最大或最小值所在索引。
(max和min函数加不加dim效果同理)
a = torch.arange(8).view(2,4).float()
a.argmax()
a.argmax(dim=1)
补:keepdim
是max函数中除了dim之外的另一个参数
a.shape为(4,10)
a.max(dim=1)
a.max(dim=1,keepdim=True)
因为dim=1,要消除第二维度,
没有keepdim=True的情况下,结果shape为(4)
keepdim=True时,结果shape为(4,1),保持了原有的二维shape
3、top-k
相较于max(dim=?)只能输出目标维度最大值及其所在索引,top-k可根据k的大小,输出目标维度前k个元素值及其所在索引。
至于是最大的k个数还是最小的k个数,取决于是否添加参数largest=False
a.topk(2,dim=1)
输出每一行前二大的元素数值,及其在该行的索引
a.topk(2,dim=1, largest=False)
4、kthvalue
相较于max(dim=?)输出目标维度最大值及其所在索引,k-th输出的是第k小的数及其所在索引
a.kthvalue(4,dim=1) #keepdim加上看着更直观
#shape为(2,4),每一行长度都为4,所以k=4时,找第四小的数就是最大的数