从今天起,总结学习 pytorch 过程中遇到的一些日后可能出错的小问题。
首先就是 pytorch 官网 tutorial 第一章讲的,numpy 类型与 torch 类型共享存储,并且还给出样例:
http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#tensors
在文章中,作者举例,当 torch 类型转换为 numpy类型时,对其中一个操作就相当于对另一个操作:
a = torch.ones(5) print(a)
Out:
1
1
1
1
1
[torch.FloatTensor of size 5]
b = a.numpy() print(b)
Out:
[ 1. 1. 1. 1. 1.]
然后执行:
a.add_(1) print(a) print(b)
Out:
2
2
2
2
2
[torch.FloatTensor of size 5]
[ 2. 2. 2. 2. 2.]
但是,我试着将代码中的 a.add_(1) 替换为 a = a + 1,结果就不是这样的:
2
2
2
2
2
[torch.FloatTensor of size (5,)]
[1. 1. 1. 1. 1.]
可以看到,这个时候 a 变了,但是 b 并没有变。
#########################################################################
同理,反过来,当 numpy 类型转换为 torch 类型的时候,作者举例如下:
import numpy as np a = np.ones(5) b = torch.from_numpy(a) np.add(a, 1, out=a) print(a) print(b)
Out:
[ 2. 2. 2. 2. 2.]
2
2
2
2
2
[torch.DoubleTensor of size 5]
如果我把代码中的 np.add(a, 1, out=a) 替换为 a = a + 1 的话,就又不共享存储了:
[[2. 2.]
[2. 2.]
[2. 2.]]
1 1
1 1
1 1
[torch.DoubleTensor of size (3,2)]
具体为什么目前还没查到,先记在这里,日后发现为什么了再补上。