#cat
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
torch.cat([a,b],dim=0)#对第一个维度进行合并,其中其他维度大小需要一致#stack
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0)#两组数据进行堆叠,同时增加一个维度#split
a = torch.rand(32,8)
c = torch.stack([a,a,a,a],dim=0)
aa,bb = c.split([3,1],dim=0)#把c拆成两份
aa,bb = c.split(2,dim=0)#chunk
aa,bb = c.chunk(2,dim=0)
二、基本运算
a = torch.rand(3,4)
b = torch.rand(3,4)
torch.all(torch.eq(a+b,torch.add(a,b)))
torch.all(torch.eq(a-b,torch.sub(a,b)))
torch.all(torch.eq(a*b,torch.mul(a,b)))
torch.all(torch.eq(a/b,torch.div(a,b)))#矩阵乘法
a = torch.tensor([[3.,3.],[3.,3.]])
b = torch.ones(2,2)
torch.mm(a,b)
torch.matmul(a,b)
a@b
#次方
a = torch.full([2,2],3)
a.pow(2)
aa = a**2
aa.sqrt()#平方
aa.rsqrt()#平方根倒数
a = torch.exp(torch.ones(2,2))
torch.log(a)
torch.log2(a)
a = torch.tensor(3.14)
a.floor()#向下取整
a.ceil()#向上取整
a.trunc()#取整数
a.frac()#取小数
a.round()#四舍五入
grad = torch.rand(2,3)
grad.max()
grad.median()
grad.clamp(0,10)#超过这个范围的赋值为0或10
三、数据统计
#范数
a = torch.full([8],1.)
b = a.view(2,4)
c = a.view(2,2,2)
a.norm(1)
b.norm(1,dim =1)
a.prod()#a所有元素的乘积
a.argmax()
a.argmin()
a.max(dim=1,keepdim=True)
a = torch.randn(4,10)
a.topk(4,dim=1)#第一个维度里,最大的四个元素,并返回位置
a.kthvalue(8,dim=1)#第第一个维度里,找到第8小的元素,并返回位置
a >0
torch.gt(a,0)
四、其他筛选操作
cond = torch.tensor([[0.6769,0.7271],[0.8884,0.4163]])
a = torch.tensor([[0.,0.],[0.,0.]])
torch.where(cond>0.5,a,b)#大于0.5返回a的元素,否则返回b