对pytorch1.7.1的API进行一次完全的解读
讲在前面
针对pytorch 1.7.1
进行一次全方位的解读,所以此博客会持续更新相当的一段时间。随后的博客会跟进torch的版本更新信息,其中能添加的实例代码我都会添加。
API
TORCH
这个包含了多维张量的数据结构和一些对于张量的数学操作。另外,提供了对张量和其他数据类型进行有效序列化的方法,还有些其他的方法。
1.Tensor(对张量本身进行的一些操作)
pytorch对于tensor的存储是一分为二的,一个创建好的张量,其头信息区叫作Tensor
,负责存储tensor的形状、步长、数据类型等,其内部数据保存为数组,存储在存储区Srorage
中。
所以,Tensor
占用的内存较小,真正占用内存的地方是Storage
。
1.1 is_tensor
- 函数源码解释:
def is_tensor(obj):
r"""如果对象为PyTorch的张量,则返回True.
Note that this function is simply doing ``isinstance(obj, Tensor)``.
Using that ``isinstance`` check is better for typechecking with mypy,
and more explicit - so it's recommended to use that instead of
``is_tensor``.
参数:
输入参数为要测试的对象
"""
return isinstance(obj, torch.Tensor)
- 代码测试:
import torch
import numpy as np
a = np.ones([1, 2])
t = torch.from_numpy(a)
print(torch.is_tensor(a))
print(torch.is_tensor(t))
print(t)
-----------
False
True
tensor([[1., 1.]], dtype=torch.float64)
1.2 is_storage
- 函数源码解释:
def is_storage(obj):
r"""如果对象是PyTorch storage的话.则返回True
参数:
测试对象
"""
return type(obj) in _storage_classes
- 代码测试:
import torch
a = torch.rand(3, 5)
print(a)
print(a.storage())
print(type(a.storage()))
print(a.storage()[0])
x = a.storage()
print(torch.is_storage(x))
----------
tensor([[0.8754, 0.7499, 0.3954, 0.3763, 0.8578],
[0.7598, 0.9750, 0.6963, 0.1602, 0.3442],
[0.9445, 0.5585, 0.2020, 0.9230, 0.4736]])
0.8753897547721863
0.7498539090156555
0.3954319953918457
0.3762640357017517
0.857833981513977
0.7598322033882141
0.9749904870986938
0.6962557435035706
0.16019368171691895
0.34423667192459106
0.9444983005523682
0.5584731101989746
0.20202845335006714
0.9230412840843201
0.473596453666687
[torch.FloatStorage of size 15]
<class 'torch.FloatStorage'>
0.8753897547721863
True
1.3 is_complex
- 函数源码解释:
def is_complex(input: Tensor) -> _bool: ...
r"""
判断输入张量是否为复数类型(torch.complex64或者torch.complex128),是则返回True。
"""
- 代码测试:
import torch
a = torch.tensor([1,