一些自己常用的pytorch函数整理
建议直接Ctrl+f搜索
torch.unsqueeze 维度增加
- torch.unsqueeze(torch.Tensorf,axis)
- 常用形式torch.unsqueeze(x,0)#最外维度+1
这个函数功能等价于numpy的expand_dim
torch.Tensor.mm矩阵相乘
- 设a=torch.ones(1,2)
- 设b=torch.rand(2,2)
- 矩阵点乘直接a*b就行
- 矩阵乘法就是a.mm(b)
torch.topk(a,k) 获取a中前k个最大的值和下标
- 对于多维度的,比如二维,结果如下
- 返回的是两个 tensor 所以我们可以写如下代码
-
values,indexs=torch.topk(a,k)
torch.tensor几种初始化
- #标准正太分布均值未0方差为1中随机抽取一组随机数
-
torch.randn(*sizes,dtype)
- 均匀分布【0,1)中均匀分布中抽取一组随机数
-
torch.rand(*sizes,out=None)
- 离散正太分布 means均值 std方差
-
torch.normal(means,std,size=(),dtype)
- 线性间距向量均匀间隔
-
torch.linspace(start,end,steps=100)
torch.view()改变形状
torch.tensor().numpy 类型转换
numpy 变成 torch
tensor=torch.form_numpy(numpy_data)
torch.stack
torch.stack()仅接受torch.tensors 注意Tensor和tensors是不一样的东西
tensors是个列表
代码如下
torch.cat
这个函数的功能和np.concatenate一毛一样不多解释了
也就后面np的axis参数换成了 dim
torch.cat((a,b),dim=0)
torchvision.transforms.compose
加载图像时候把几个变换操作串联起来
np.transpose(imgs,(1,2,0))
交换维度,常用于画图时候用,因为 torch的维度方式和plt不同
torch是batch channel size
plt是channel size batch