2.5 Tensor的API

欢迎订阅本专栏:《PyTorch深度学习实践》
订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html

  • 第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础。
  • 第三章:学习PyTorch如何读入各种外部数据
  • 第四章:利用PyTorch从头到尾创建、训练、评估一个模型,理解与熟悉PyTorch实现模型的每个步骤,用到的模块与方法。
  • 第五章:学习如何利用PyTorch提供的3种方法去创建各种模型结构。
  • 第六章:利用PyTorch实现简单与经典的模型全过程:简单二分类、手写字体识别、词向量的实现、自编码器实现。
  • 第七章利用PyTorch实现复杂模型:翻译机(nlp领域)、生成对抗网络(GAN)、强化学习(RL)、风格迁移(cv领域)。
  • 第八章:PyTorch的其他高级用法:模型在不同框架之间的迁移、可视化、多个GPU并行计算。

Pytorch提供了很多tensor的API, 如若逐条去记忆想必自然不是程序员的学习方法,更不是AI从业者的学习方法,API是帮助我们快速开发的工具,绝不能成为学习的压力,因此相对有效的方法是“类”–> “查” -->“熟”。

  • “类”:对纷繁众多的API进行归类,先归大类,再在大类中归小类,对知识形成一个系统的认知;

  • “查”:当在实际的应用中忽然想实现某个功能的时候,在“类”的帮助下马上知道了是什么类别下的API能帮助实现,于是去官方文档中迅速查找到函数名、使用方法及例子;

  • “熟”:这是自然而然的过程,在不断高效“查”的过程中对“类”也逐渐熟悉,于是根本无需死记硬背,就对海量的API熟悉了。

pytorch tensor API的官方文档地址:http://pytorch.org/docs

文档中给API归了大类,但是每个大类罗列的API是根据首字母的顺序,也不方便大家快速认知,因此本章的小节中不但按照官方的思路对大类进行讲解,也根据笔者的思路对大类再做了更细致的归类,方便大家对API的认知。

2.5.1 调用API的方式

方式一:操作后创建新的变量

方式一中也有两种不同的方式可以选用

  • (1)使用Torch module中的方法

用上文提到过的转置方法来举例, 可以调用torch.transpose(), 将原tensor:a作为参数传入, 生成新的tensor:a_t。

import torch

a = torch.ones(3, 2)
a_t = torch.transpose(a, 0, 1)
  • (2)使用tensor object的中方法

直接在tensor:a上调用transpose方法,结果同上,

a = torch.ones(3, 2)
a_t = a.transpose(0, 1)

因此两种方法可以互相替换使用,但注意目前(2)中tensor对象的方法数目相对较少,用(1)torch moduel中的功能更全面。

方式二:在原变量上做操作

对tensor调用带有下划线结尾的方法,即为对原tensor进行操作,无需赋值新的变量

a = torch.ones(3, 2)
a.zero_()
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

2.5.2 索引、切块

(1)简单索引与切块

对tensor的索引和对python中list的索引是一样的,只是tensor是多维数组,list是一维的。对list的索引不再赘述,下面看几个tensor索引的例子。

points = torch.tensor([[1.0, 4.0],[2.0, 1.0],[3.0, 5.0]])
points[1:]  # 获取第2行开始的所有行 [[2.0, 1.0],[3.0, 5.0]]
points[1:, :]  # 获取第2行开始的所有行的所有列,结果与上同 [[2.0, 1.0],[3.0, 5.0]]
points[1:, 0]  # 获取第2行开始的所有行的第0列, [[2.0],[3.0]]
tensor([2., 3.])
(2)高级索引与切块

a.通过索引切割tensor: torch.index_select() & torch.take()

a = torch.rand(4,6)
b = torch.index_select(input=a, dim=1, index=torch.LongTensor([0,1,2,3]))

print(b)
tensor([[0.5042, 0.0299, 0.0558, 0.9661],
        [0.0617, 0.8948, 0.4300, 0.0555],
        [0.0531, 0.1290, 0.4504, 0.7697],
        [0.4639, 0.8179, 0.8393, 0.6148]])

对a, 在维度1上,切出索引在[0,3)之间的元素。
注意:index必须是LongTensor类型的。

torch.take()也是通过索引取数,只是将Input平铺成一维的:

b = torch.take(input=a, index=torch.LongTensor([0,1,2,3]))
print(b)
tensor([0.5042, 0.0299, 0.0558, 0.9661])
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值