1.连续的概念
1、contiguous连续的两种方法
Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致
Tensor多维数组底层实现是使用一块连续内存的1维数组(行优先顺序存储,下文描述),Tensor在元信息里保存了多维数组的形状,在访问元素时,通过多维度索引转化成1维数组相对于数组起始位置的偏移量即可找到对应的数据。
如果想要变得连续使用contiguous方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的,如果Tensor是连续的,则contiguous无操作。
is_contiguous():用来判断是否连续
contiguous():不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的,如果Tensor是连续的,则contiguous无操作。
2、行优先
>>> t = torch.arange(12).reshape(3,4)
>>> t
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
数组 t 在内存中实际以一维数组形式存储,通过 flatten 方法查看 t 的一维展开形式,实际存储形式与一维展开一致,如图2,
>>> t.flatten()
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
3、stride偏移量
这里列出对应列优先的矩阵(transpose(0,1))
这个时候按行访问:就需要每次01234偏移4个才能访问到,因为底层存储形式是123456789
也就是我不管你到底表面变成什么形状,我底层存储的都是0123456789,保存的是最开始初始化的矩阵按行展开的形式,如果你想改变形状,好,这里改变stride来迎合你!
这里就是stride的更改!
>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>>t.stride()
(4, 1)
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
>>>t2.stride()
(1, 4)
>>>t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
True
>>>t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
(True, False)
2.为什么view前要变连续
1.view在底层操作
上文的t和t2都使用的是同一份数据
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
都是通过stride来改变的,这里view无法考虑之前进行的transpose等操作,view默认直接对这个底层存储进行操作。此刻就算不报错也是对
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]进行的操作。
2.contiguous()的作用
>>>t3 = t2.contiguous()
>>>t3
tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
>>>t3.data_ptr() == t2.data_ptr() # 底层数据不是同一个一维数组
False
此时将底层的信息给改变了,这时候的t3和t2是完全不同的tensor了,view操作也就能接上上文的transpose了,tensor([ 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11])。
2.reshape更加方便
为了解决用户使用便捷性问题,PyTorch在0.4版本以后提供了reshape方法,实现了类似于 tensor.contigous().view(*args)的功能,如果不关心底层数据是否使用了新的内存,则使用reshape方法更方便。
原文:
https://zhuanlan.zhihu.com/p/64551412