文章目录
前言
Pytorch学习笔记第三篇,关于Tensor的合并(cat/stack)、分割(split/chunk)与基本运算。
一、合并Cat/Stack
1.Cat
Tensor中cat是contract的缩写,代表着两个张量Tensor在制定维度上进行合并,这就要求这两个张量Tensor在其余维度的长度一致。
代码如下(示例):
import torch
#1 cat
a=torch.rand(5,32,8)
b=torch.rand(4,32,8)
c=torch.cat([a,b],dim=0) #将a,b在0维上合并 ->[9.32.8]
2.Stack
stack也可以用于Tensor的合并,但区别于cat,stack会在指定索引上创造一个新维度,因此stack要求原来的两个张量Tensor必须维度形状完全一致。
代码如下(示例):
#2 stack 创造一个新的维度
a=torch.rand(32,8)
b=torch.rand(32,8)
c=torch.stack([a,b],dim=0) #在0维增加新的维度->[2,32,8],其中c[0]=a,c[1]=b
#dim!=0时则是后续维度与前面Tensor的后续部分按维度匹配。
二、分割Split/Chunk
1.Split
split对张量Tensor在指定维度上进行划分,split的划分是按照长度进行划分的,因此输入的参数为分割后各张量的长度。
代码如下(示例):
#3 split 按长度拆分
a=torch.rand(4,2,2)
aa,bb=a.split(2,dim=0) #dim=0上拆分为2个长度为2的张量,aa=a[:2],bb=a[2:]
aa,bb,cc=a.split([2,1,1],dim=0) #dim=0上拆分3个,aa=a[:2],bb=a[2],cc=a[3]
2.Chunk
chunk与split不同的在于,chunk需要指定的是分解后张量的个数,而非结果张量的长度
代码如下(示例):
#4 chunk 按个数拆分
a=torch.rand(8,2,2)
aa,bb=a.chunk(2,dim=0) #aa,bb->[4,2,2]
三、基本运算
1.加减乘除
对应于torch中add、sub、mul、div,且已经重载为+、-、*、/,具有广播机制。
代码如下(示例):
#1 基本运算符+-*/ 对应元素运算,拥有广播机制
a=torch.rand(5,3)
b=torch.rand(3)
a+b
a-b
a*b
a/b
2.矩阵乘法mm/@/matmul
矩阵乘法的运算分为mm、@、matmul。
其中mm只能作用于2维矩阵。
matmul可以作用于dim>=2的矩阵,其机理为最后两维做矩阵乘法,前面的维数保持不变或广播。
@是matmul的运算符重载,使用方便。
代码如下(示例):
#2 矩阵乘法 mm/matmul/@
c=torch.rand(5,3)
d=torch.rand(3,4)
e=torch.rand(784,3)
torch.mm(c,d) #只适用于dim=2,不建议使用
torch.matmul(c,d) #适用于任一情况
c@e.t() #运算符重载@为矩阵乘法
#dim>2时,最后两维做矩阵乘法@,前面的维数保持不变或广播
e=torch.rand(4,3,7,8)
f=torch.rand(4,1,8,9)
e@f #[4,3,7,9]
torch.matmul(e,f)
3.幂运算**
幂运算可以采用pow、sqrt、rsqrt进行,也可以采用重载运算符**进行。
pow可以指定幂指数
sqrt求平方根
rsqrt求平方根的倒数
代码如下(示例):
#3 幂运算**/pow
a**2 #平方
a**(0.5) #开方
a.pow(2)
a.sqrt()
a.rsqrt()
4.指数exp/对数log
exp求以e为底的指数结果。
log求自然对数ln运算结果。
log10、log2等为底数不同的对数运算
代码如下(示例):
#4 exp/log
torch.exp(a) #以e为底求幂
torch.log(a) #取对数ln
torch.log10(a) #取对数lg
5.近似floor/ceil/round/trunc/frac
对浮点Tensor进行近似。
floor:向下取整
ceil:向上取整
round:四舍五入
trunc:取整数部分
frac:取小数部分
代码如下(示例):
#5 近似floor/ceil/round/trunc/frac
g=torch.tensor(3.1415)
g.floor() #向下取整3.
g.ceil() #向上取整4.
g.trunc() #取整数部分3.(浮点)
g.frac() #取小数部分0.1415
g.roung() #四舍五入
6.裁剪(归化)clamp
clamp作用于张量每一个元素,将会指定范围,并将超出范围[min,max]的数据规范到min、max
代码如下(示例):
#6 clamp裁剪(将超出范围的数规范到min/max)
h=torch.rand(3,3)*10 #0-10之间随机浮点
h1=h.clamp(5) #小于5的归化为5
h2=h.clamp(6,7) #小于6的归化为6,大于7的归化为7
总结
以上是Tensor的合并、分割与基本运算,下一篇计划为Tensor统计操作与高级操作。
2021.2.18