pytorch学习
torch的Tensor维度变换
view和reshape功能一样-不变的是数据本身,变维度改变对数据的理解
缩小/扩大维度,正的维度在索引之后插入,负的维度在索引之前插入
例子:给每幅图一个偏置bias
右边扩两次,左边一次,从[32]得到[1 32 1 1 ]
维度扩展:expend/repeat,expend在需要的时候复制数据,节约内存,推荐;repeat复了数据
expend扩展1->N可行,大于1需要指定策略
-1代表维持不变
repeat:函数里的数字都是重复次数
.t():矩阵转置
transpose:两两交换
transpose和permute都可能会导致内存打乱,要用contiguous函数连续起来
permute
Broadcast自动扩展
broadcast的扩张步骤理解
broadcast为自动扩张
图中第2行第2列的0 1 2 维度为1,即[3],broadcast首先 把[3]->[1 3],再进行扩张,步骤为:
- 在高维度上插入1
- 把所有shape为1的Tensor扩张为shape和a相同的shape
一个浮点数是4byte
理解broadcast的存在意义:
[5.0]存数据,[1]代表1维度
如果[1,32,8] + [5.0],一直相加会有维度不同的报错,因此5.0要进行扩张,不用broadcast的做法:
[1].unsqueeze(0).unsqueeze(0).expand_as(A)(expand换成repeat更是会加倍内存消耗,而broadcast 可以节省内存),使用broadcast
下图还是说扩张的步骤
例子
broadcast不行:
broadcast都是从低纬度开始(右边开始)
拼接和拆分
- Cat
- Stack
- Split
- Chunk
- Cat维度拼接
dim=?就拼接那个维度,数据都要补齐
Stack添加维度
stack维度一定要一致
Split
按长度/数目[3,1,1]按311拆分拆分
Chunk
按数目拆分
Tensor基本运算
+:add或者+
-:sub或者-
:mul或者*
/:div或者/——一个**/符号代表除法,两个//**代表整除
矩阵相乘
:是对应位置相乘
torch.mm只适用2d,不推荐
torch.matmul=@:矩阵相乘,@是matmul的重载
图中w为(512,784),a(4,784)@w(784,512)=(4,512),为什么w的512放第一位呢:Pytorch习惯把输出放第一位,t():转置,适合2d,高维用transpose
四维四维:前两位不变,矩阵乘后两位,前两位如果不一致就适用broadcast
次方函数
sqrt:平方根
rsqrt:倒数
e,log
以2为底:log2
以10为底:log10
floor(地板):往下走取整
ceil(天花板):往上走取整
trunc,frac(裁剪):裁剪为整数部分和小数部分
round:四舍五入
clamp(裁剪)
clamp(10):裁剪到最小值为10
clamp(0,10):裁剪到0-10