1、拼接与拆分(merge or split)
常用函数
=====Cat(拼接)
特点:不会增加新的维度
torch.cat([a,b], dim=0):将张量a和b在0维上进行拼接,其余维必须相等
=====Stack(拼接)
特点:会增加新的维度
torch.stack([a,b], dim=0)
stack会创建一个新的维度,由dim设置
例子中,会在第0维设置2:取0表示a,取1表示b
=====Split(拆分)
特点:根据长度拆分,给定每个拆分后张量大小
a.split(1, dim=0):在第0维拆分,每个大小为1
a.split([2,1], dim=0):在第0维拆分,拆成一个长度为2的,一个长度为1的
=====Chunk(拆分)
特点:根据数量拆分,给定要拆成多少个张量
a.chunk(2, dim=0):在第0维拆分,平均拆成2个
2、基本运算
数学运算
=====add/sub/mul/div
a*b和torch.mul(a,b):都是元素相乘
torch.matmul(a,b):a和b需要满足矩阵相乘法则
=====matmul(矩阵相乘)
torch.mm():只适合二维
torch.matmul()和@:都适合
特别对于高维张量:只对后面二维相乘
(4,3,28,64)@(4,3,64,32) ⇒ (4,3,28,32)
如果前面两个维度值不同,但符合broadcasting机制,可以使用broadcasting
(4,3,28,64)@(4,1,64,32) ⇒ (4,3,28,32)
=====pow/sqrt/rsqrt
a.pow(n):n次方
a.sqrt():开方
a.rsqrt():开方之后取倒数
=====torch.exp/torch.log
=====近似
a.floor():向下取整
a.ceil():向上取整
a.trunc():取整数
a.frac():取小数
a.round():四舍五入
=====clamp限制(一般用于限制参数梯度值)
grad.clamp(10):最小值为10,小于10的全部变为10
grad.clamp(0,10):最小值为0,最大值为10