tensor torch相关操作
基础运算
加
import torch
a = torch.rand(3,4)
b= torch.rand(3,4)
c=a+b #或
c=torch.add(a,b)#维度不同时会对broadcasting
减
d=a-b#或
d=torch.sub(a,b)
乘
#element wise,对应元素乘
e=a*b
e=torch.mul(a,b)
二维矩阵相乘
e=torch.mm(a,b.t) #(a, b)* ( b, c ) = ( a, c )
多维元素最后两个维度乘
e=torch.matmul(a,b)
点乘,内积
一维张量对应位置相乘,返回标量
torch.dot(x,y)
除,对应元素除
f=a/b
f=div(a,b)
幂运算
g=a.pow(n)
开方
h=sqrt()
h=rsqrt()#倒数
对数
k=torch.exp(torch.ones(2,3))#得到全是e的2*3矩阵
i=log2(a)
i=log10(a)
i=log(a)
近似值
l=a.floor()#取上
l=a.ceil()#取下
l=a.trunc()#取整
l=a.frac()#取小数
最大最小中位数均值众数标准差方差
a.max()
a.min()
a.median()中位数
a.mean()均值
a.mode()众数
a.std()标准差
a.var()方差
裁剪运算
torch.clamp(input,min,max, out=None)超出的限制为min,max
给定阈值截取部分tensor
拿出大于阈值的结果
import torch
a=torch.randn(8,3)
print(a)
#b用来拿满足条件的索引,这里判断的是索引2大于0.5
b=a[...,2]>0.5
print(b)
a=a[b]
print(a)
tensor([[ 2.0951, -0.4598, 0.7132],
[-0.6455, 0.7299, -0.2005],
[-0.4402, -1.5987, 0.9192],
[ 0.5835, -0.0969, -0.3439],
[ 0.5355, -0.8484, 1.3264],
[-0.5404, -0.6408, -0.2357],
[-0.2434, 0.5452, 0.6899],
[-0.1293, 0.3484, 0.0673]])
tensor([ True, False, True, False, True, False, True, False])
tensor([[ 2.0951, -0.4598, 0.7132],
[-0.4402, -1.5987, 0.9192],
[ 0.5355, -0.8484, 1.3264],
[-0.2434, 0.5452, 0.6899]])

本文详细介绍TensorTorch中的基础运算操作,包括加减乘除等基本数学运算、近似值处理、最大最小值及中位数计算等,并提供如何通过设定阈值来筛选张量元素的具体实例。
1750

被折叠的 条评论
为什么被折叠?



