欢迎订阅本专栏:《PyTorch深度学习实践》
订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html
- 第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础。
- 第三章:学习PyTorch如何读入各种外部数据
- 第四章:利用PyTorch从头到尾创建、训练、评估一个模型,理解与熟悉PyTorch实现模型的每个步骤,用到的模块与方法。
- 第五章:学习如何利用PyTorch提供的3种方法去创建各种模型结构。
- 第六章:利用PyTorch实现简单与经典的模型全过程:简单二分类、手写字体识别、词向量的实现、自编码器实现。
- 第七章利用PyTorch实现复杂模型:翻译机(nlp领域)、生成对抗网络(GAN)、强化学习(RL)、风格迁移(cv领域)。
- 第八章:PyTorch的其他高级用法:模型在不同框架之间的迁移、可视化、多个GPU并行计算。
Pytorch提供了很多tensor的API, 如若逐条去记忆想必自然不是程序员的学习方法,更不是AI从业者的学习方法,API是帮助我们快速开发的工具,绝不能成为学习的压力,因此相对有效的方法是“类”–> “查” -->“熟”。
-
“类”:对纷繁众多的API进行归类,先归大类,再在大类中归小类,对知识形成一个系统的认知;
-
“查”:当在实际的应用中忽然想实现某个功能的时候,在“类”的帮助下马上知道了是什么类别下的API能帮助实现,于是去官方文档中迅速查找到函数名、使用方法及例子;
-
“熟”:这是自然而然的过程,在不断高效“查”的过程中对“类”也逐渐熟悉,于是根本无需死记硬背,就对海量的API熟悉了。
pytorch tensor API的官方文档地址:http://pytorch.org/docs
文档中给API归了大类,但是每个大类罗列的API是根据首字母的顺序,也不方便大家快速认知,因此本章的小节中不但按照官方的思路对大类进行讲解,也根据笔者的思路对大类再做了更细致的归类,方便大家对API的认知。
2.5.1 调用API的方式
方式一:操作后创建新的变量
方式一中也有两种不同的方式可以选用
- (1)使用Torch module中的方法
用上文提到过的转置方法来举例, 可以调用torch.transpose(), 将原tensor:a作为参数传入, 生成新的tensor:a_t。
import torch
a = torch.ones(3, 2)
a_t = torch.transpose(a, 0, 1)
- (2)使用tensor object的中方法
直接在tensor:a上调用transpose方法,结果同上,
a = torch.ones(3, 2)
a_t = a.transpose(0, 1)
因此两种方法可以互相替换使用,但注意目前(2)中tensor对象的方法数目相对较少,用(1)torch moduel中的功能更全面。
方式二:在原变量上做操作
对tensor调用带有下划线结尾的方法,即为对原tensor进行操作,无需赋值新的变量
a = torch.ones(3, 2)
a.zero_()
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
2.5.2 索引、切块
(1)简单索引与切块
对tensor的索引和对python中list的索引是一样的,只是tensor是多维数组,list是一维的。对list的索引不再赘述,下面看几个tensor索引的例子。
points = torch.tensor([[1.0, 4.0],[2.0, 1.0],[3.0, 5.0]])
points[1:] # 获取第2行开始的所有行 [[2.0, 1.0],[3.0, 5.0]]
points[1:, :] # 获取第2行开始的所有行的所有列,结果与上同 [[2.0, 1.0],[3.0, 5.0]]
points[1:, 0] # 获取第2行开始的所有行的第0列, [[2.0],[3.0]]
tensor([2., 3.])
(2)高级索引与切块
a.通过索引切割tensor: torch.index_select() & torch.take()
a = torch.rand(4,6)
b = torch.index_select(input=a, dim=1, index=torch.LongTensor([0,1,2,3]))
print(b)
tensor([[0.5042, 0.0299, 0.0558, 0.9661],
[0.0617, 0.8948, 0.4300, 0.0555],
[0.0531, 0.1290, 0.4504, 0.7697],
[0.4639, 0.8179, 0.8393, 0.6148]])
对a, 在维度1上,切出索引在[0,3)之间的元素。
注意:index必须是LongTensor类型的。
torch.take()也是通过索引取数,只是将Input平铺成一维的:
b = torch.take(input=a, index=torch.LongTensor([0,1,2,3]))
print(b)
tensor([0.5042, 0.0299, 0.0558, 0.9661])