在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,以”行优先"进行存储,tensor连续(contiquous)是指tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同,tensor不连续会导致某些操作无法进行,比如view()就无法进行。
import torch
a = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print(a)
print(a.storage())# a是连续的print(a.is_contiguous())
b = a.t()# b是a的转置
print(b)
print(b.storage())
print(b.is_contiguous())
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
1
2
3
4
5
6
7
8
9
[torch.LongStorage of size 9]
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
1
2
3
4
5
6
7
8
9
[torch.LongStorage of size 9]
False
b应该是1 4 7 2 5 8 3 6 9才是连续的,虽然它经过a的转置,但是并没有改变其原先的存储顺序,所以出现的结果就是不连续的。
使用.contiguous()实现其为行存储 :
c = b.contiguous()
print(b)
print(c)
print(b.storage())
print(c.storage())
print(c.is_contiguous())
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
tensor([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
1
2
3
4
5
6
7
8
9
[torch.LongStorage of size 9]
1
4
7
2
5
8
3
6
9
[torch.LongStorage of size 9]
True
可以发现c已经和b的存储顺序不一样了,是连续的了。
参考:28-masked_fill张量掩码|Dropout正则化|view|register_buffer|contiguous|多头注意力-transformer_哔哩哔哩_bilibili