常用shape操作
目录
文章目录
1. 增加/删除维度
> 删除: torch.tensor.squeeze(dim)
功能: 指定维度若为1, 则删除该维度, 否则不改变
举例
```python
t = torch.tensor.ones(2, 1, 3) # 假设t为tensor对象 t.shape = (2, 1, 3) dim的index依次为0, 1, 2
t = t.squeeze(dim=1) # 返回t.shape = (2, 3) dim1被去掉
t = t.squeeze(dim=0) # 仍然为(2, 1, 3) 因为 dim0 != 1```
```
> 增加: torch.tensor.unsqueeze(dim)
功能: 与squeeze相反 在dim上插入维度 新增的维度等于1
在指定位置上插入, 其余维度后移
举例
t = torch.tensor.zeros(2, 3) #shape=(2, 3)
t = t.squeeze(dim=0) # 在dim0上插入一个维度
# shape = (1, 2, 3)
2. 改变形状
> torch.tensor.reshape()
> torch.tensor.view()
举例
torch.randint(0, 20, (4, 5))
y = x.reshape(20)
z = x.view(2, 2, 5)
> 重点*两者的区别
详细介绍链接
- view创造的对象时共享空间的, 当原tensor某个数据改变后, 被view赋值的tensor对应位置的数据也要随之改变 (同一段数据, 不同的形状的引用)
- reshape创造的对象空间不共享, 原tensor的改变对新tensor没有影响
3. 推平flatten
nn.flatten()
4. transpose用于pytorch彩色图片在plt里输出
在plt中彩色图片的输入shape是(im_h, im_w, channels)
假设维度数组为(3, 64, 64)
transpose可以改变维度数组的顺序
如transpose(1, 2, 0)
可以把数据reshape成(64,64,3)