Tensor
与 NumPy
有很高的相似性,彼此之间的互操作也非常简单有效,需要注意的是 Tensor
与 NumPy
共享内存,由于 NumPy
历史悠久,所以遇到 Tensor
不支持的操作时,可以先转换成 NumPy
,处理后再转换成 Tensor
,转换开销很小。
1. Tensor 转化为 NumPy
In [1]: import torch as t
In [2]: a = t.ones(5)
In [3]: a
Out[3]: tensor([1., 1., 1., 1., 1.])
In [4]: b = a.numpy()
In [5]: b
Out[5]: array([1., 1., 1., 1., 1.], dtype=float32)
2. NumPy 转化为 Tensor
In [6]: import numpy as np
In [7]: a = np.ones(5)
In [8]: a
Out[8]: array([1., 1., 1., 1., 1.])
In [9]: b =t.from_numpy(a)
In [10]: b
Out[10]: tensor([1., 1., 1., 1., 1.], dtype=torch.float64)
值得注意的是,Torch
中的 Tensor
和 NumPy
中的 Array
共享内存位置,一个改变,另一个也同样改变。注意使用的是 b.add_()
。
In [10]: b
Out[10]: tensor([1., 1., 1., 1., 1.], dtype=torch.float64)
In [11]: b.add_(1)
Out[11]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
In [12]: a
Out[12]: array([2., 2., 2., 2., 2.])
In [13]: b
Out[13]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
下面看使用 b.add()
发现 b 并没有改变。
In [13]: b
Out[13]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
In [14]: b.add(2)
Out[14]: tensor([4., 4., 4., 4., 4.], dtype=torch.float64)
In [15]: a
Out[15]: array([2., 2., 2., 2., 2.])
In [16]: b
Out[16]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
b.add_()
和 b.add()
的区别:
任何操作符都固定地在前面加上 _
来表示替换。例如:y.copy_(x)
,y.t_()
,都将改变 y
。
3. PyTorch 广播法则
当输入数组的某个维度的长度为 1 时,计算时沿此维度复制扩充成一样的形状。
可以通过以下两个函数组合手动实现广播法则:
unsqueeze
或者view
: 为数据的某一维的形状补 1,实现法则 1expand
或者expand_as
,重复数组,实现法则 3;该操作不会复制数组,所以不会占用额外空间
注意: repeat
实现和 expand
相类似的功能,但是 repeat
会把形同的数据复制多份,因此会占用额外的空间。
3.1 自动广播法则
In [17]: a = t.ones(3,2)
In [18]: a
Out[18]:
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
In [19]: b = t.zeros(2,3,1)
In [20]: b
Out[20]:
tensor([[[0.],
[0.],
[0.]],
[[0.],
[0.],
[0.]]])
可以看到 a
是二维的,而 b
是三维的,但是可以通过广播法则直接进行相加计算。
In [23]: a + b
Out[23]:
tensor([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]])
In [24]:
3.2 手动广播法则
In [24]: a.unsqueeze(0).expand(2,3,2) + b.expand(2,3,2)
Out[24]:
tensor([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]])
In [25]:
4. Numpy 广播法则
- 让所有输入数组都向其中
shape
最长的数组看齐,shape
中不足的部分可通过在前面加 1 补齐; - 两个数组要么在某一维度的长度一致,要么其中一个为 1,否则不能计算;