import torch
import numpy as np
一、合并与切割
1.concat
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(5, 3, 32, 32)
# Concatenate a1 and a2 along dimension 0: sizes add on that dim (4 + 5 = 9)
torch.cat([a1, a2], dim=0).shape
out:torch.Size([9, 3, 32, 32])
a2 = torch.rand(4, 1, 32, 32)
# Concatenate a1 and a2 along dimension 1 (3 + 1 = 4)
torch.cat([a1, a2], dim=1).shape
out:torch.Size([4, 4, 32, 32])
# If the non-concatenated dimensions do not match, cat raises an error
torch.cat([a1, a2], dim=0).shape
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-7-c6c20381a457> in <module>
1 # 如果维度不匹配,则会报错
----> 2 torch.cat([a1, a2], dim=0).shape
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at ..\aten\src\TH/generic/THTensor.cpp:711
2.stack
a = torch.rand(32, 8)
b = torch.rand(32, 8)
# stack merges the given tensors into one new tensor by inserting a brand-new
# dimension at the given position; everything else stays unchanged
# e.g. stacking along dim 0 yields a new (2, 32, 8) tensor
c = torch.stack([a, b], dim = 0)
c.shape
out:torch.Size([2, 32, 8])
# index 0 along the new dim of the stacked (2, 32, 8) tensor is exactly a (index 1 is b)
print(torch.equal(torch.stack([a, b], dim = 0)[0], a))
out:True
torch.stack([a, b], dim = 1).shape
torch.Size([32, 2, 8])
# stack is strict about rank and per-dimension sizes: all inputs must have
# exactly the same shape, otherwise it raises an error
c = torch.rand(30, 8)
torch.stack([a, c], dim=0)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-12-2f6aa1a168d8> in <module>
1 # stack对维度和每一维度的通道数有严格的要求,若不合适则会报错
2 c = torch.rand(30, 8)
----> 3 torch.stack([a, c], dim=0)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 32 and 30 in dimension 1 at ..\aten\src\TH/generic/THTensor.cpp:711
3.split
# The previous step stacked a and b into c; split takes them apart again
# The merge was along dim 0, so we split along dim 0 (piece sizes 1 and 1)
aa, bb = c.split([1,1], dim = 0)
aa.shape, bb.shape
out:(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
# Verify that the pieces really match the original tensors
print(torch.equal(aa[0], a))
out:True
4.chunk
# Same goal as above, but with chunk: e.g. cut c into two pieces
aa, bb = c.chunk(2, dim = 0)
aa.shape, bb.shape
out:(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
d = torch.rand(8, 32, 8)
# chunk(3) divides dim 0 as evenly as possible: ceil(8/3)=3, 3, and remainder 2
aa, bb, cc = d.chunk(3, dim = 0)
aa.shape, bb.shape, cc.shape
out:(torch.Size([3, 32, 8]), torch.Size([3, 32, 8]), torch.Size([2, 32, 8]))
# split, by contrast, lets you choose each piece's size along the dim explicitly
aa, bb, cc = d.split([4, 3, 1], dim = 0)
aa.shape, bb.shape, cc.shape
out:(torch.Size([4, 32, 8]), torch.Size([3, 32, 8]), torch.Size([1, 32, 8]))
二、运算
1.乘法
a = torch.tensor([[3,3],[3,3]])
b = torch.tensor([[1,1], [1,1]])
# a * b is plain element-wise multiplication
a * b
out:tensor([[3, 3],
[3, 3]])
# Matrix multiplication (not recommended: mm only supports 2-D tensors)
torch.mm(a, b)
out:tensor([[6, 6],
[6, 6]])
# Matrix multiplication
torch.matmul(a, b)
out:tensor([[6, 6],
[6, 6]])
# Matrix multiplication (@ is the operator form of matmul)
a @ b
out:tensor([[6, 6],
[6, 6]])
# For >2-D inputs matmul multiplies over the last two dims only, i.e. a
# batch of matrix products computed in parallel
a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 3, 64, 28)
c = torch.matmul(a, b)
c.shape
out:torch.Size([4, 3, 28, 28])
2.其他
# Build a (2, 2) tensor filled entirely with 3
# NOTE(review): this transcript shows a float result; newer PyTorch infers
# int64 from an integer fill value — confirm against your torch version
a = torch.full([2,2], 3)
a.pow(2)
out:tensor([[9., 9.],
[9., 9.]])
aa = a ** 2
aa
out:tensor([[9., 9.],
[9., 9.]])
a = torch.tensor(3.14)
# floor, ceil, integer part, fractional part, round-to-nearest
a, a.floor(),a.ceil(), a.trunc(), a.frac(), a.round()
out:(tensor(3.1400),
tensor(3.),
tensor(4.),
tensor(3.),
tensor(0.1400),
tensor(3.))
grad = torch.rand(2, 3) * 15
grad
out:tensor([[ 5.0516, 12.1264, 3.0216],
[12.0383, 7.7159, 13.8156]])
grad.max()
out:tensor(13.8156)
grad.median()
out:tensor(7.7159)
# Clamp every value below 10 up to 10 (single-arg clamp sets only the minimum)
grad.clamp(10)
out:tensor([[10.0000, 12.1264, 10.0000],
[12.0383, 10.0000, 13.8156]])
l1范数:
$\|x\|_{1}=\sum_{i=1}^{n}\left|x_{i}\right|$
l2范数:
$\|x\|_{2}=\sqrt{\sum_{i=1}^{N} x_{i}^{2}}$
a = torch.full([8], 1)
b = a.view(2, 4)
c = a.view(2, 2, 2)
out:tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]])
a
out:tensor([1., 1., 1., 1., 1., 1., 1., 1.])
b
out:tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]])
c
out:tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]])
# l1 norm (sum of absolute values) over the whole tensor; all three share
# the same 8 elements, so every result is 8
a.norm(1), b.norm(1), c.norm(1)
out:(tensor(8.), tensor(8.), tensor(8.))
# l1 norm along a chosen dimension
b.norm(1, dim = 1)
out:tensor([4., 4.])
# l1 norm along dim 0 (columns of the (2, 4) tensor b)
b.norm(1, dim = 0)
out:tensor([2., 2., 2., 2.])
# l2 norm: sqrt of the sum of squares; sqrt(8) ≈ 2.8284 for all three views
a.norm(2), b.norm(2), c.norm(2)
out:(tensor(2.8284), tensor(2.8284), tensor(2.8284))