文章目录
前言
view() 函数是进行张量维度重构的函数,permute() 和 transpose() 是进行张量维度转换的函数,高阶张量由若干低阶张量构成,如结构为 (n, c, h, w)的 4 阶张量由 n 个结构为 (c, h, w) 的 3 阶张量构成,结构为 (c, h, w)的 3 阶张量由 c 个结构为 (h, w) 的 2 阶张量构成,结构为 (h, w)的 2 阶张量又由 h 个长度为 w 的 1 阶张量构成,h 为行数,w 为列数。
1. reshape()
reshape() 函数与 view() 函数都是进行维度重组的函数,使用方法类似,区别在于 view() 函数只能对张量进行操作,而 reshape() 函数既可以对张量进行操作,还可以对 numpy 数组进行操作,代码示例如下,具体原理见 view() 函数。
x = np.array([1, 2, 3, 4, 5, 6]) # 一个大小为 6 的一维 numpy 数组
y = torch.Tensor([1, 2, 3, 4, 5, 6]) # 一个大小为 6 的一阶张量
print(x.reshape(2, 3)) # 重组 x 为结构为 (2, 3) 的数组
print(y.reshape(2, 3)) # 重组 y 为结构为 (2, 3) 的张量
2. view()
① 1 阶变高阶
1 阶变 2 阶
对于一个 1 阶张量 x,进行 view(h, w) 操作就是按照索引先后顺序每次从 x 中取出 w 个元素作为作为一行数据,共取 h 次,构成一个 (h, w) 结构的 2 阶张量,具体见示例。
x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8]) # 一个含有 8 个元素的 1 阶张量
print(x.view(4, 2)) # 返回一个 (4, 2) 结构的 2 阶张量
1 阶变 3 阶
对于一个 1 阶张量 x,进行 view(c, h, w) 操作就是按照索引先后顺序每次从 x 中取出 h*w 个元素,对这 h*w 个元素按照 1 阶张量转 2 阶数张量的方法转为一个 (h, w) 结构的 2 阶张量,共取 c 次,构成一个 (c, h, w) 结构的 3 阶张量,具体见示例。
x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) # 一个含有 12 个元素的 1 阶张量
print(x.view(3, 2, 2)) # 返回一个 (3, 2, 2) 结构的 3 阶张量
1 阶变 4 阶
对于一个 1 阶张量 x,进行 view(n, c, h, w) 操作就是按照索引先后顺序每次从 x 中取出 c*h*w 个元素,对这 c*h*w 个元素按照 1 阶张量转 3 阶张量的方法转为一个 (c, h, w) 结构的 3 阶张量,共取 n 次,最终构成一个 (n, c, h, w) 结构的 4 阶张量,具体见示例。
# # 一个含有 24 个元素的 1 阶张量
x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
print(x.view(2, 2, 2, 3)) # 返回一个 (2, 2, 2, 3) 结构的 4 阶张量
1 阶变 m 阶
对于一个 1 阶张量 x,进行 view( i n i_n in, i n − 1 i_{n-1} in−1, ···, i 2 i_2 i2, i 1 i_1 i1) 操作就是按照索引先后顺序每次从 x 中取出 i n − 1 i_{n-1} in−1* i n − 2 i_{n-2} in−2*···* i 2 i_2 i2* i 1 i_1 i1 个元素,对这 i n − 1 i_{n-1} in−1* i n − 2 i_{n-2} in−2*···* i 2 i_2 i2* i 1 i_1 i1 个元素按照 1 阶张量转 m-1 阶张量的方法转为一个 ( i n − 1 i_{n-1} in−1, ···, i 2 i_2 i2, i 1 i_1 i1) 结构的 m-1 阶张量,共取 m 次,最终构成一个 ( i n i_n in, i n − 1 i_{n-1} in−1, ···, i 2 i_2 i2, i 1 i_1 i1) 结构的 m 阶张量,其中 i n i_n in 代表张量第 n 个索引的值。
② 2 阶变 m 阶
对于一个 2 阶张量 x,结构为 (h, w),要变成一个 m 阶的新张量,首先将该 2 阶张量按行展开成一个大小为 h*w 的 1 阶张量,再按照 1 阶变 m 阶的方法变为一个 m 阶张量,按行展开就是在 w 索引方向上进行拼接,2 阶张量变 3 阶张量的代码示例见下,用一个 1 阶张量来验证分析。
x = torch.Tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]) # 一个 (4, 3) 结构的 2 阶张量
y = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) # 一个含有 12 个元素的一阶张量
print(x.view(2, 2, 3)) # 返回一个 (2, 2, 3) 结构的 3 阶张量
print(y.view(2, 2, 3)) # 返回一个 (2, 2, 3) 结构的 3 阶张量
③ 3 阶变 m 阶
对于一个 3 阶张量 x,结构为 (c, h, w),要变成一个 m 阶的新张量,首先将该 3 阶张量按行拼接得到一个结构为 (c*h, w) 的 2 阶张量,再按照 2 阶变 1 阶的方法转变为一个 1 阶张量,按行拼接就是在 h 索引方向上进行拼接,示例见图 1.1 和图 1.2。
3 阶张量变 4 阶张量的代码示例见下, 用一个拼接后得到的 2 阶张量来验证前述分析。
x = torch.Tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]],
[[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24]]]) # 一个 (4, 2, 3) 结构的 3 阶张量
y = torch.Tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18],
[19, 20, 21],
[22, 23, 24]]) # 一个 (4*2, 3) 结构的 2 阶张量
print((y.view(2, 2, 2, 3)).equal(x.view(2, 2, 2, 3))) # 两个张量转变后的结果是否相等
print(x.view(2, 2, 2, 3)) # 返回一个 (2, 2, 2, 3) 结构的 4 阶张量
④ 4 阶变 m 阶
对于一个 4 阶张量 x,结构为 (n, c, h, w),要变成一个 m 阶的新张量,首先将该 m 阶张量在 c 索引方向进行拼接得到一个结构为 (n*c, h, w) 的 3 阶张量,再按照 3 阶张量变 1 阶张量的方法转变为一个 1 阶张量,最后再按照 1阶变 m 阶的方法得到 m 阶张量,4 阶张量在 c 索引方向进行拼接的示意图如图 2.1和图 2.2 所示。
4 阶张量变 2 阶张量的代码示例见下,用一个拼接后的 3 阶张量验证前述分析。
x = torch.Tensor([[[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]],
[[[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24]]]]) # 一个 (2, 2, 2, 3) 结构的 4 阶张量
y = torch.Tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]],
[[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24]]]) # 一个 (2*2, 2, 3) 结构的 3 阶张量
print(f'x.size() = {x.size()}') # x(2, 2, 2, 3)
print(x.view(4, 6)) # 返回结构为 (4, 6) 的 2 阶张量
print(f'y.size() = {y.size()}') # y(4, 2, 3)
print(y.view(4, 6)) # 返回结构为 (4, 6) 的 2 阶张量
3. transpose()
transpose()
函数一次进行两个维度的交换,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多。
② 2 阶张量
对于一个 2 阶张量,结构为 (h, w),对应 transpose()
函数中的参数是 (0, 1) 两个索引,进行 transpose(0, 1)
操作就是在交换 h, w 两个维度,得到的结果与常见的矩阵转置相同,具体代码示例见下。
x = torch.Tensor([[1, 2],
[3, 4],
[5, 6]]) # 一个结构为 (3, 2) 的 2 阶张量
print(f'x.size() = {x.size()}') # 返回张量 x 的结构
y = x.transpose(0, 1) # 交换 h, w 两个维度
# y = x.t() # 对 x 进行转置
print(f'y.size() = {y.size()}') # 返回张量 y 的结构
print(y) # 打印交换维度后的张量 y,结构为 (2, 3)
③ 3 阶张量
对于一个 3 阶张量,结构为 (c, h, w),对应 transpose()
函数中的参数是 (0, 1, 2) 3 个索引,进行 transpose(0, 1)
操作就是在交换 c, h 两个维度,交换 c, h 两个维度的示意图见图 3.1 和图 3.2,其他维度的交换方式同理,实在不明白可以拿几本书放一起比划一下。
3 阶张量交换 c, h 维度的代码示例见下,不难发现对 3 阶张量的 c, h 两个索引进行 transpose()
操作就是以 w 索引方向为轴在进行旋转。
x = torch.Tensor([[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]],
[[13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24]]]) # 一个结构为 (4, 2, 3) 的 3 阶张量
print(f'x.size() = {x.size()}') # 返回张量 x 的结构
print(x.transpose(0, 1)) # 交换张量的 c, h 维度, 结构为 (2, 4, 3)
④ 4 阶张量
对于一个 4 阶张量,结构为 (n,c, h, w),对应 transpose() 函数中的参数是 (0, 1, 2,3) 4 个索引,对应 transpose()
的操作相对复杂一些,为方便理解这里具体分为 transpose(0, 1)
,和 transpose(0, 3)
,和 transpose(1, 2)
三种,具体原因见以下分析。
3.4.1 transpose(0, 1) 操作就是交换 n, c 两个维度,交换 n, c 两个维度的示意图见图 4.1 和图 4.2,实在不明白可以拿几本书比划一下。
4 阶张量交换 n, c 维度的代码示例见下,其他维度交换同理,不难发现对 4 阶张量而言进行 transpose(0, 1)
操作就是 n 索引方向上进行通道重新分组,如下代码中原张量 n 索引方向上有 2 组,每组有 3 个通道,交换 n, c 维度后变为 3 组,每组有 2 个通道。
x = torch.Tensor([[[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
[[9, 10], [11, 12]]],
[[[13, 14], [15, 16]],
[[17, 18], [19, 20]],
[[21, 22], [23, 24]]]]) # 一个结构为 (2, 3, 2, 2) 的 4 阶张量
print(f'x.size() = {x.size()}') # 返回张量 x 的结构
print(x.transpose(0, 1)) # 返回交换 n, c 维度后的张量,结构为 (3, 2, 2, 2)
3.4.2 transpose(0, 3)
操作就是交换 n, w 两个维度,交换 n, w 两个维度的示意图比较难表示,这里用代码解释一下,transpose(0, 2)
同理,需要说明的是这种变换方式很少会用到。
对于一个结构为 (2, 2, 2, 3) 的 4 阶张量 x,进行 transpose(0, 3)
操作即将原 4 阶张量变成一个结构为 (3, 2, 2, 2) 的新 4 阶张量,可以理解为在保证原 4 阶张量中元素 c , h 索引不变的情况下的将每一个元素的 n, w 进行交换,类似于坐标系变换。
x = torch.Tensor([[[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]]],
[[[13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24]]]]) # 结构为 (2, 2, 2, 3) 的 4 阶张量
print(f'x.size() = {x.size()}') # 返回张量 x 的结构
y = x.transpose(0, 3) # 交换 n, w 维度
print(f'y.size() = {y.size()}') # 返回张量 y 的结构
print(y)
3.4.3 transpose(1, 2)
操作就是交换 c, h 两个维度,跟之前的 3 阶张量交换维度的操作一样,4 阶张量只是需要交换 n 个 3 阶张量的维度,transpose(1, 3)
和 transpose(2, 3)
同理。
4. permute()
permute()
函数一次可以进行多个维度的交换或者可以成为维度重新排列,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多,本质上可以理解为多个 transpose()
操作的叠加,因此理解 permute()
函数的关键在于理解 transpose()
函数,代码示例如下。
x = torch.Tensor([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]]) # 一个结构为 (2, 3, 4) 的 3 阶张量
print(f'x.size() = {x.size()}') # 返回张量 x 的结构
y = x.permute(2, 0, 1) # 对张量 x 进行维度重排
z = x.transpose(0, 1).transpose(0, 2) # 对张量 x 连续交换两次维度
print(y.equal(z)) # 判断张量 y 和张量 z 是否相同
print(f'z.size() = {z.size()}') # 返回张量 z 的结构
print(z)
结语
通过以上分析可以得出结论,reshpe()
和 view()
两个函数满足条件时可以根据需要设置维度,而 transpose()
和 permute()
两个函数只能在已有的维度之间进行变换,另外 transpose()
函数在 pytorch 和 numpy 中略有不同,numpy 中的 transpose()
函数相当于 pytorch 中的 permute()
函数。