Pytorch学习(4):Tensor统计、where与gather


前言

Pytorch学习笔记第四篇,关于Tensor的统计(max、min、mean等)、where、gather。


一、统计

1.范数norm

从目前学习的知识来看,pytorch提供p范数,基本都是各元素的幂之和开根号,似乎没有学到矩阵范数。
同时norm函数可以指定dim进行求范数,得到每一个dim向量的范数。

代码如下(示例):

import torch

sample=torch.ones(8)
s1=sample.view([4,2])
s2=sample.view([2,2,2])
#1 范数
sample.norm(1)  #1范数 8
sample.norm(2)  #2范数 2√2

s1.norm(1,dim=1) #在dim=1(列)求范数,结果形状为[4],结果为[2,2,2,2]
s2.norm(2,dim=0) #在dim=0求范数,结果形状为[2,2],结果为[[1.414,1.414],[1.414,1.414]]

2.max/min/mean/sum/prod

Tensor元素的重要特征:
max:最大值
min:最小值
mean:平均值
sum:和
prod:积

代码如下(示例):

#2 重要数据特征max、min、mean、sum、prod
sample2=torch.rand(3,3)
sample2.max() #最大值
sample2.min() #最小值
sample2.mean() #均值
sample2.sum() #和
sample2.prod() #积

3.argmax/argmin/dim/keepdim

argmax/argmin可以得到Tensor最大值与最小值的索引,但不指定dim的情况下,argmax和argmin都会进行打平view,最终得到的索引也是打平后的索引。
指定维度dim,则会返回该维度的最值索引,而选定keepdim=True则会保持索引结果与原张量Tensor的dim一致(不是shape一致)。

代码如下(示例):

#3 argmax,argmin最值索引
#不指定维度,argmax和argmin都会打平tensor,索引也是在打平后的索引
sample2.argmax() #返回最大值的索引
sample2.argmin() #最小值索引
#指定维度,返回该维度上各个最值索引
sample2.argmax(dim=1) #返回形状为[3]的tensor,共有三个索引
sample2.argmin(dim=0) #同理
#keepdim=True会保持dim=2不变
sample2.argmax(dim=1,keepdim=True) #返回形状为[3,1]的tensor
sample2.max(dim=1,keepdim=True) #返回[3,1]形状的值tensor与[3,1]形状的索引tensor

4.Topk/kthvalue

topk返回前k个元素的张量,以及这些元素的索引张量,可以通过选定dim指定维度,默认为前k个最大的;选定largest=False,即选前k个最小的。
kthvalue选择第k个元素,返回张量与索引张量。

代码如下(示例):

#4 Topk/kthvalue
sample2.topk(2,dim=1) #返回在dim=1上返回最大的两个,形状为[3,2],也返回索引
sample2.topk(2,dim=1,largest=False) #返回最小的2个
sample2.kthvalue(2,dim=1) #返回dim=1第二大的值与索引

5.比较/eq/equal

对两个形状一致的张量Tensor进行比较,得到新的张量,新Tensor为bool型,各个元素为两Tensor比较结果。
eq与比较类似,返回bool型张量
equal返回bool值(不是张量),判断两个张量是否完全一致。

代码如下(示例):

#5 >,<,>=,<=,eq,equal
sample3=torch.randint(0,5,[5,5])
sample4=torch.randint(0,5,[5,5])
sample4>sample3 #与numpy类似,返回一个各个元素比较的bool矩阵/向量

torch.eq(sample4,sample3) #返回各个元素比较的bool矩阵
torch.eq(sample3,sample4) #返回True/False,只有完全一样返回True

一、高级操作where/gather

1.条件where

where接受条件cond张量与A、B张量,满足cond时,结果对应元素取A中对应元素,反之取B中对应元素。
代码如下(示例):

#1 where(cond,A,B) 满足条件cond的时候,结果取A中对应元素,反之取B
cond=torch.rand(3,3)
A=torch.rand(3,3)
B=torch.rand(3,3)
result=torch.where(cond>0.5,A,B)

2.gather

按照张量index取张量input指定dim的元素生成新tensor,可GPU加速

代码如下(示例):

#2 gather(按照张量index取张量input制定dim的元素生成新tensor,可GPU加速)
input=torch.rand(5,7)
index=torch.randint(0,7,[4,1])
result=torch.gather(input,dim=1,index=index)  #结果是result[i][j]=input[i][index[i][j]]

总结

以上是Tensor的统计、where、gather,下一篇计划为Pytorch中梯度与优化。
2021.2.19

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值