相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482
在Pytorch中,transpose是Tensor类的一个重要方法,同时它也是torch模块中的一个函数,它们的语法如下所示。
Tensor.transpose(dim0, dim1) → Tensor
torch.transpose(input, dim0, dim1) → Tensor
input (Tensor) – the input tensor.
dim0 (int) – the first dimension to be transposed
dim1 (int) – the second dimension to be transposed
官方的解释如下:
返回一个张量,它是输入张量的转置版本,其中将给定的维度dim0和dim1交换。返回的新对象可能会变得不连续(使用is_contiguous()方法可以鉴定是否连续)。
如果输入是一个具有步幅的张量(常规稠密张量),那么输出张量将与输入张量共享底层存储,因此改变一个张量的内容将改变另一个张量的内容。
如果输入是一个稀疏张量,那么输出张量不与输入张量共享底层存储。
如果输入是压缩布局的稀疏张量(SparseCSR, SparseBSR, SparseCSC或SparseBSC),参数dim0和dim1必须同时是批处理维度,或者必须同时是稀疏维度(稀疏张量的批处理维度是稀疏维之前的维度)。
在详细说明之前,我们需要明确tensor的形状相关的概念,对于一个tensor来说,它的维度数和维度大小是两个概念,维度数是从0开始依次递增的,分别称为第0维、第1维...,每一个维度又有自己的大小。官方解释中的dim0和dim1指的是维度数。
tensor([[1, 2, 3],
[4, 5, 6]])
这个张量有2个维度,分别是0维和1维,第0维大小是2,第1维大小是3
import torch
# 创建一个张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 使用transpose操作
y = x.transpose(0, 1)
# 等价于y = x.transpose(1, 0)
print(x, y)
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 4],
[2, 5],
[3, 6]])
print(id(x),id(y))
4347791776 5006922112 # 说明两个张量对象不同
print(x.storage().data_ptr(), y.storage().data_ptr())
5207345856 5207345856 # 说明两个张量对象里面保存的数据存储是共享的
y[0, 0] = 7
print(x, y)
tensor([[7, 2, 3],
[4, 5, 6]])
tensor([[7, 4],
[2, 5],
[3, 6]]) # 说明对新tensor的更改影响了原tensor
print(x.is_contiguous(), y.is_contiguous())
True False # 说明x是连续的,y不是连续的
以上的内容,类似于之前在关于python中列表的浅拷贝中说到的那样,对新列表内部嵌套的列表中的元素的更改会影响原列表。如下所示。 列表的浅拷贝
import copy
my_list = [1, 2, [1, 2]]
your_list = list(my_list) #工厂函数
his_list = my_list[:] #切片操作
her_list = copy.copy(my_list) #copy模块的copy函数
your_list[2][0] = 3
print(my_list)
print(your_list)
print(his_list)
print(her_list)
his_list[2][1] = 4
print(my_list)
print(your_list)
print(his_list)
print(her_list)
her_list[2].append(5)
print(my_list)
print(your_list)
print(his_list)
print(her_list)
输出
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
但不一样的是,在这里甚至对tensor中非嵌套的内容的修改,也会导致另一个tensor受到影响,如下所示。
import torch
# 创建一个张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 使用transpose操作
y = x.transpose(0, 1)
# 等价于y = x.transpose(1, 0)
print(x, y)
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 4],
[2, 5],
[3, 6]])
x[0] = torch.tensor[4, 4, 4] # 改变其中一个tensor的第0个元素
print(x, y)
tensor([[4, 4, 4],
[4, 5, 6]])
tensor([[4, 4],
[4, 5],
[4, 6]])
在pytorch中,和transpose方法类似的还有方法permute,它比transpose功能更加强大,详细细节可以看下面的文章。