在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,以“行优先”进行存储。
tensor的连续性
tensor连续(contiguous)是指tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同。如下图所示:
上图中,tensor b是tensor a经过转置而来的,即使用了 tensor.t() 方法。
出现不连续现象,本质上是由于pytorch中不同tensor可能共用同一个storage导致的。
pytorch的很多操作都会导致tensor不连续,如tensor.transpose()(tensor.t())、tensor.narrow()、tensor.expand()。
以转置为例,因为转置操作前后共用同一个storage,但显然转置后的tensor按照行优先排列成1维后与原storage不同了,因此转置后结果属于不连续(见下例)。
2. tensor.is_contiguous()
is_contiguous
直观的解释是Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致。
如果想要变得连续使用contiguous
方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的;如果Tensor是连续的,则contiguous
无操作。
tensor.is_contiguous()用于判断tensor是否连续,以转置为例说明:
import torch
a = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print(a)
print(a.storage())
print(a.is_contiguous()) # a是连续的
"""
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
1
2
3
4
5
6
7
8
9
[torch.LongStorage of size 9]
True
"""
b = a.t() # b是a的转置
print(b)
print(b.storage())
print(b.is_contiguous()) # b是不连续的
"""
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
1
2
3
4
5
6
7
8
9
[torch.LongStorage of size 9]
False
"""
3. tensor不连续的后果
tensor不连续会导致某些操作无法进行,比如view()就无法进行。在上面的例子中:由于 b 是不连续的,所以对其进行view()操作会报错;b.view(3,3)没报错,因为b本身的shape就是(3,3)。
print(b.view(3, 3))
"""
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
"""
print(b.view(1, 9))# 报错
print(b.view(-1))# 报错
4. tensor.contiguous()
tensor.contiguous()返回一个与原始tensor有相同元素的 “连续”tensor,如果原始tensor本身就是连续的,则返回原始tensor。
注意:tensor.contiguous()函数不会对原始数据做任何修改,他不仅返回一个新tensor,还为这个新tensor创建了一个新的storage,在这个storage上,该新的tensor是连续的。
继续使用上面的例子:
c = b.contiguous()
print(b)
print(c)
print(b.storage())
print(c.storage())
输出结果:
# b
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
# c
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
# b.storage
1
2
3
4
5
6
7
8
9
[torch.LongStorage of size 9]
#c.storage
1
4
7
2
5
8
3
6
9
[torch.LongStorage of size 9]
接着运行如下代码:
print(b.is_contiguous()) # False
print(c.is_contiguous()) # True
print(c.view(1, 9)) # tensor([[1, 4, 7, 2, 5, 8, 3, 6, 9]])
参考自:https://blog.csdn.net/baidu_41774120/article/details/128666944