pytorch进阶
Broadcast 广播 自动扩展
- python中自带Broadcast(广播)机制,当一个多维数据加低维数据时,会自动复制低维数据,帮低维数据自动补全至 与高维数据格式匹配
拼接与拆分
cat (拼接)
- 使用 torch.cat ( [ a , b ], dim = c ) 可以将 a 和 b 的 c 维度拼接在一起
- 注意:除了 参数中要拼接的维度dim,其他的维度size必须一致
- 下例中第 8 行就是因为 a1,a2 其他维度 size 不同所以报错
stack (拼接)
- torch.stack ( [ a , b ], dim = c ) 是在c处增加一个维度,该维度的size取决于要拼接的数据个数,此处为2
split 按长度拆分
- .split ( [a , b] ,dim = c ) -> 按照 a,b 的长度,将其拆分为 c 维度为 a 和 c 维度为 b 的两个值
- 参数 [a,b] 也可以为一个单值 d,即,.split ( d ,dim = c ) -> 将其 c 维度按 d 的长度等分 e 份,最后分为 e 个 c 维度为 d 的值
chunk 按个数拆分
- .chunk( a, dim=b ) 将 b 维度分为 a 份 ,即 分为 a 个 b 维度为 (b/a) 的值
数学运算
Add/minus/multiply/divide 加减乘除
- 直接使用 .add() 、.minus()、.multiply()、divide() 即可代替常规的加减乘除运算
- .eq(a,b) 是判断两个值是否相等
mm / matmul / @ 矩阵相乘
- .mm() 只支持二维矩阵相乘,不常用
- .matmul() 支持任意维度矩阵相乘,常用
- @ 是numpy中矩阵相乘的符号,和 matmul 用法一致
pow / sqrt / rsqrt 方根/根号/求导
- pow(a)代表 a 次方根
- sqrt / rsqrt 分别代表 根号和求导
- 可以在矩阵中使用,对矩阵中每个元素操作
exp / log 以e为底 / 以10为底
.floor() .ceil().trunc() .frac() 向上取整、向上取整、将整数与小数分开、取小数部分
clamp 梯度裁剪
- .clap(a) 所有不足 a 的值全部换为 a
- .clap(b,c) b为最小值,c为最大值,小于b的用b替换,大于c的用c替换
统计属性
norm 范数
- .norm( a ) -> 求a范数
- .norm( a, dim=b ) -> 对b维度求a范数,求完后,b维度消失
mean, sum, min, max, prod,argmax , argmin 平均、求和、最小、最大、累乘、最大索引、最小索引
- .argmax , .argmin 不同于 max,min , .argmax , .argmin会将矩阵拉成一维,返回最大 / 最小值 的索引
- 但可以采用 .argmax(dim=a) 的格式,获取a维度对应的最大值,如下例中,dim=1,即对列求最大值,得出每行对应的最大值索引(列的位置)
dim,keepdim max等函数的重要参数
- 对 max 加入 dim 参数,即,.max(dim=a) 可以返回该 a 维度上的最大值和其所在的索引
- 对 max 加入 dim 和 keepdim 参数,不光返回其值和索引,还能保持原来的维度(不加的话会自动消除维度,变成一维的输出)
topk,kthvalue 范围大小
- 使用 .topk( a,dim = b) 获得 b 维度上前 a 大的 a 个数及其索引 (默认最大值,参数 largest =false 后变为求最小值)
- 使用 .kthvalue( a,dim = b) 获得 b 维度上 倒数第 a 小的 数及其索引
gt,eq 大于、等于
- .gt(a,b) 等价于 > ,判断a中大于b的元素,正确返回1,否则返回0
- .eq(a,b) 判断 a是否等于 b ,正确返回1,否则返回0
- .equal(a,b) 判断 a是否等于b ,返回ture或false
where,gather
- torch.where(condition , x ,y ) 表示,当条件 condition 满足时,输出 x,不满足时输出 y
- 为了满足都是用GPU并行运算,常采用 where 一步解决
- 实际应用见下例
- torch.gather( input , dim, index, out=None ) 为查表操作,返回 index索引对应 在input上dim维度的值