一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】——python & torch

一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】

1. reshape

reshape() 函数: 用于在不更改数据的情况下为数组赋予新形状。
注意: 用于低维度转高维度

c = np.arange(6)
print("** ", c)
c1 = c.reshape(3, -1)
print("** ", c1)
c2 = c.reshape(-1, 6)
print("** ", c2)
**  [0 1 2 3 4 5]
**  [[0 1]
 [2 3]
 [4 5]]
**  [[0 1 2 3 4 5]]

2. view

torch中,view() 的作用相当于numpy中的reshape,重新定义矩阵的形状。

v1 = torch.range(1, 16) 
v2 = v1.view(-1, 4)  
print("** ", v1)
print("** ", v2)
**  tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16.])
**  tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])

3. rearrange

rearrange是einops中的一个函数调用方法。

from einops import rearrange

image1 = torch.zeros(2, 224, 224, 3)
image2 = rearrange(image1, 'b w h c -> b (w h) c')
image3 = rearrange(image1, 'b w h c -> (b c) (w h)')

print("** ", image1.shape)
print("** ", image2.shape)
print("** ", image3.shape)
**  torch.Size([2, 224, 224, 3])
**  torch.Size([2, 50176, 3])
**  torch.Size([6, 50176])

4. transpose

torch.transpose(Tensor,dim0,dim1)是pytorch中的ndarray矩阵进行转置的操作。
注意: transpose()一次只能在两个维度间进行转置(也可以理解为维度转换)

x = torch.Tensor(2, 3, 4, 5)  # 这是一个4维的矩阵(只用空间位置,没有数据)
print(x.shape)
# 先转置0维和1维,之后在第2,3维间转置,之后在第1,3间转置
y = x.transpose(0, 1).transpose(3, 2).transpose(1, 3)
print(y.shape)
torch.Size([2, 3, 4, 5])
torch.Size([3, 4, 5, 2])

5. permute

注意: permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度,permute是transpose的进阶版。

print(torch.Tensor(2,3,4,5).permute(3,2,0,1).shape)
torch.Size([5, 4, 2, 3])

6. contiguous

x.is_contiguous() ——判断tensor是否连续
x.contiguous() ——把tensor变成在内存中连续分布的形式

需要变成连续分布的情况:

view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。

x = torch.Tensor(5, 10)
print(x.is_contiguous())
print(x.transpose(0, 1).is_contiguous())
print(x.transpose(0, 1).contiguous().is_contiguous())
True
False
True

写代码时,一般没写contiguous()会报错提示,所以不用担心…

7. squeeze

squeeze()函数的功能是维度压缩。返回一个tensor(张量),其中 input 中大小为1的所有维都已删除。

x = torch.Tensor(2, 1, 2, 1, 2)
print(x.shape)
y = torch.squeeze(x) # 默认是把所有是1的维度都删掉
print(y.shape)
y = torch.squeeze(x, 1)
print(y.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 2, 1, 2])

8. unsqueeze

squeeze()函数的功能是增加维度

x = torch.arange(0,6)
print(x.shape)
y = x.unsqueeze(0)
print(y.shape)
z = x.unsqueeze(1)
print(z.shape)
w = x.unsqueeze(2)
print(w.shape)
torch.Size([6])
torch.Size([1, 6])
torch.Size([6, 1])
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
### PyTorch 中 Tensor 的 `squeeze`、`reshape` 和 `unsqueeze` 函数详解 #### 1. `torch.squeeze()` 函数 `squeeze()` 函数用于移除张量中尺寸为 1 的维度。如果指定了某个特定的维度,则仅当该维度大小为 1 时才会被移除;如果没有指定任何维度参数,则会自动移除所有尺寸为 1 的维度。 - **语法**: ```python torch.squeeze(input, dim=None) ``` - **功能说明**: - 如果不提供 `dim` 参数,那么所有尺寸为 1 的维度都会被删除。 - 如果提供了 `dim` 参数,只有这个维度上的大小为 1 才会被删除,其他情况下不会有任何改变。 - **示例代码**: ```python import torch a = torch.tensor([[[[1], [2], [3]]]]) print("原始张量:", a) print("原始形状:", a.shape) # 输出: (1, 3, 1) b = torch.squeeze(a) print("经过 squeeze 后的张量:", b) print("经过 squeeze 后的形状:", b.shape) # 输出: (3,) c = torch.squeeze(a, dim=0) print("沿第 0 维度挤压后的张量:", c) print("沿第 0 维度挤压后的形状:", c.shape) # 输出: (3, 1) ``` --- #### 2. `torch.unsqueeze()` 函数 `unsqueeze()` 函数的作用是在指定位置增加一个新的维度,新维度的大小默认为 1。 - **语法**: ```python torch.unsqueeze(input, dim) ``` - **功能说明**: - 新增的维度由 `dim` 参数决定。 - 增加的新维度可以位于任意位置(前、中间或后)。 - **示例代码**: ```python d = torch.tensor([1, 2, 3]) print("原始张量:", d) print("原始形状:", d.shape) # 输出: (3,) e = torch.unsqueeze(d, dim=0) print("新增第 0 维度后的张量:", e) print("新增第 0 维度后的形状:", e.shape) # 输出: (1, 3) f = torch.unsqueeze(d, dim=1) print("新增第 1 维度后的张量:", f) print("新增第 1 维度后的形状:", f.shape) # 输出: (3, 1) ``` --- #### 3. `torch.reshape()` 函数 `reshape()` 函数允许重新定义张量的形状,只要满足总元素数量不变即可。 - **语法**: ```python input.reshape(shape) ``` - **功能说明**: - 可以通过 `-1` 自动推导某一个维度的大小。 - 不同于 `.view()` 方法,`.reshape()` 支持跨内存布局的操作,因此即使原张量不是连续存储也可以成功调用。 - **示例代码**: ```python g = torch.arange(6) print("原始张量:", g) print("原始形状:", g.shape) # 输出: (6,) h = g.reshape(2, 3) print("重塑后的张量:", h) print("重塑后的形状:", h.shape) # 输出: (2, 3) i = g.reshape(-1, 2) print("使用 -1 推导后的张量:", i) print("使用 -1 推导后的形状:", i.shape) # 输出: (3, 2) ``` --- ### 总结对比 | 功能 | 描述 | |--------------|----------------------------------------------------------------------------------------| | `squeeze()` | 移除尺寸为 1 的维度[^1] | | `unsqueeze()`| 在指定位置插入新的维度,其大小固定为 1[^5] | | `reshape()` | 更改张量的整体结构,前提是保持总的元素数目一致[^3] | 上述三个函数均适用于调整张量的维度,在实际应用中可以根据需求灵活组合它们来完成复杂的矩阵变换操作。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

青春是首不老歌丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值