reshape view作为pytorch中torch的常用操作,有一些小细节需要注意一下。
x = torch.arange(12)
print('x')
print(x)
print(id(x))
y = x.reshape(3,4)
print('y')
print(y)
print(id(y))
x[:] = 2
print('x')
print(x)
print('y')
print(y)
以上代码输出结果为
x
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
140575189363840
y
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
140575189361680
x
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
y
tensor([[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2]])
可以看到,x进行reshape操作后生存的y,在地址上是不一样的。但是当x的元素值修改时,y的元素值也跟着修改,这表明reshape操作可以理解为仅仅创建了一个新的指向数据的指针,并没有去开辟新的内存空间,复制新的一份数据。也就是说,reshape操作后,数据存储还是一份,但是指向該数据的指针增多一个。
同时,与reshape操作类似,view也是同种效果.
这表明,reshape操作和view操作都是创建一个新的指针变量,指向原有的数据。尽管数据的shape不一样,但是值是一样的。
那么如何复制原有数据,并且完全创建新的一份数据呢,使得两份数据之间互不干扰?
可以使用clone函数。
x = torch.arange(12)
print('x')
print(x)
print(id(x))
y = torch.clone(x).reshape(3,4)
print('y')
print(y)
print(id(y))
x[:] = 2
print('x')
print(x)
print('y')
print(y)
结果为
x
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
140575196239024
y
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
140575189994624
x
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
y
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
总结
1.使用reshape或者view操作后,数据的指针变量修改,但是数据并没有改。因此修改数据数值,会同时影响多个变量。
2.要想完全复制新的一份数据,包括开辟内存空间,赋予新的指针变量,使用clone函数。
这里有一个博客写得相当不错