张量形状的重塑
首先阐述一下什么是 “视图”:视图是数据的一个别称或引用,通过该别称或引用便可访问、操作原有数据,原有数据不会产生拷贝。如果我们对视图进行修改,它会影响到原始数据,因为物理内存在同一位置,这样避免了重新创建张量的高内存开销。对张量的大部分操作就是视图操作。
与视图相对应的概念就是 “副本”:副本是一个数据的完整的拷贝,如果我们对副本进行修改,它不会影响到原始数据,因为物理内存不在同一位置。
torch 中的 view() 和 reshape() 的区别:
torch 中的 view() 与 reshape() 方法都可以用来重塑 tensor 的 shape,区别就在于使用的条件不一样。
view() 方法将 tensor 转换为指定的 shape,原始的 data 不改变,返回的新 tensor 与原始 tensor 共享存储区。新 tensor 的维度必须是原始维度的子空间,或满足连续条件,否则需要先使用 contiguous() 方法将原始 tensor 转换为满足连续条件的 tensor,再使用 view() 方法进行 shape 的变换。换句话说,view() 方法只适用于满足连续性条件的 tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称或引用,返回值为视图。
reshape() 方法与 view() 方法类似,但 reshape() 方法更强大,它的返回值既可以是视图,也可以是副本,当满足连续性条件时返回视图,否则返回副本。可以认为 a.reshape() = a.view() + a.contiguous().view(),在满足 tensor 的连续性条件时,a.reshape() 返回的结果与 a.view() 相同,否则返回的结果与 a.contiguous().view() 相同。
因此当不确定能否使用 view() 方法时,可以使用 reshape()。如果只是想简单地重塑一个 tensor 的 shape,那可以直接用 reshape();但如果需要考虑内存的开销,而且要确保重塑后的 tensor 与之前的 tensor 共享存储空间,那就使用 view()。
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
print(x.view(4, 3))
print(x.reshape(2, 6))
print(torch.reshape(x, [6, -1]))
---------
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10],
[11, 12]])