【Pytorch】Pytorch基础

张量的结构操作

一、创建张量

张量创建的方法和Numpy中创建array的方法十分相似。

1.1 从Python列表或者元组创建张量

a = torch.tensor([1,2,3], dtype=torch.float)
a = torch.tensor((1,2,3), dtype=torch.float)

1.2 使用arange生成张量

b = torch.arange(start=1, end=10, step=1)

1.3 使用linspace/logspace生成张量

c = torch.linspace(start=0, end=10, steps=10, requires_grad=True)
# 注意torch.linspace/logspace中的steps参数和torch.arange中的step参数的区别
c = torch.logspace(start=0, end=10, steps=10, base=10, requires_grad=False)

1.4 使用ones/zeros创建张量

d = torch.zeros((3,3))
d = torch.ones((2,3))

需要注意的是torch.zeros_liketorch.ones_like,二者可以快速生成给定tensor一样shape的0或1向量。

e = torch.zeros_like(d, dtype=torch.int)
e = torch.ones_like(d, dtype=torch.float)

1.5 创建随机张量

# torch.randint --> Returns a tensor filled with random integers generated uniformly
g = torch.randint(low=0, high=10, size=[2,2])
# 0-1均匀分布
f = torch.rand([5])
# 均匀随机分布
f = torch.randn([5])
# 正态随机分布
# mean (Tensor): the tensor of per-element means
# std (Tensor): the tensor of per-element standard deviations
f = torch.normal(mean=torch.zeros(3,3),std=torch.ones(3,3))
# 整数随机排列
# torch.randperm --> Returns a random permutation of integers from ``0`` to ``n - 1``.
f = torch.randperm(20)

1.6 创建特殊矩阵

# 单位矩阵
g = torch.eye(2,2)
# 对角矩阵
# 注意torch.diag的输入必须是一个tensor
g = torch.diag(torch.tensor([1,2,3]))

二、索引切片

张量的索引和切片与Numpy亦十分类似,切片时支持缺省函数和省略号,也可以通过索引和切片对部分元素进行修改。

# 使用省略号可以表示多个冒号
In[0]: print(a)
Out[0]: tensor([[[ 0,  1,  2],
		         [ 3,  4,  5],
		         [ 6,  7,  8]],
		
		        [[ 9, 10, 11],
		         [12, 13, 14],
		         [15, 16, 17]],
		
		        [[18, 19, 20],
		         [21, 22, 23],
		         [24, 25, 26]]])
Out[1]: print(a[...,1])
Out[1]: tensor([[ 1,  4,  7],
		        [10, 13, 16],
		        [19, 22, 25]])

对于不规则的切片提取,可以采用如torch.index_selecttorch.taketorch.gathertorch.masked_select等方法。上述这些方法可以完成提取张量的部分元素值,但不能更改张量的部分元素值得到新的张量。如果需要修改张量的部分元素得到新的张量,可以使用torch.wheretorch.index_filltorch.masked_fill;其中torch.index_fill和torch.masked_fill选取元素逻辑分别与torch.index_select和torch.masked_select相同。

2.1 torch.index_select

Pytorch: torch.index_select
该函数有三个参数:

  1. input:即被索引的张量
  2. dim:即索引的维度
  3. index:index参数属性为IntTensor或者LongTenosr,index是一个一维保存期望索引目标的序列( the 1-D tensor containing the indices to index)

2.2 torch.take

Pytorch: torch.take
t o r c h . t a k e torch.take torch.take函数首先将输入的Tensor展开为一维张量,输出一个与 i n d e x index index参数相同shape的张量;该函数有两个参数:

  1. input:输入张量
  2. index:该参数属性为LongTensor,存储我们期望索引数据的索引下标

2.3 torch.gather

Pytorch: torch.gather

2.4 torch.masked_select

Pytorch: torch.masked_select
该函数返回一个一维的张量,这个张量由输入的张量map一个为布尔张量的mask选择得到。

  1. input (Tensor) – the input tensor.
  2. mask (BoolTensor) – the tensor containing the binary mask to index with

2.5 torch.where

Pytorch: torch.where
参数:

  1. condition: 如果condition为True,返回x,否则返回y
  2. x: 从condition这个boolean张量为True的index返回x对应位置的元素。
  3. y: 元素选择逻辑与x相同

三、维度变换

Pytorch中用于维度变换的函数主要有torch.reshapetorch.squeezetorch.unsqueezetorch.transpose

3.1 torch.squeeze

Pytorch: torch.squeeze
如果张量在某个维度上只有一个元素,使用这个函数可以消除这个维度,如将 t o r c h . S i z e ( [ 1 , 2 ] ) torch.Size([1,2]) torch.Size([1,2])形状的张量变为 t o r c h . S i z e ( [ 2 ] ) torch.Size([2]) torch.Size([2])
torch.unsqueeze的作用与该函数作用效果相反。

3.2 torch.transpose

Pytorch: torch.transpose
该函数用于交换张量的维度,常用于图片存储格式的变换上。如果张量是一个二维的矩阵,通常会使用 m a t r i x . t ( ) matrix.t() matrix.t(),这个操作等价于 t o r c h . t r a n s p o s e ( m a t r i x , 0 , 1 ) torch.transpose(matrix, 0, 1) torch.transpose(matrix,0,1)
参数为:

  1. input:输入张量
  2. dim0:第一个需要被转置的维度
  3. dim1:第二个需要被转置的维度

四、合并分割

Pytorch中提供了torch.stacktorch.cat来将多个张量合并,torch.split将一个张量分割为多个张量。注意torch.stack会增加维度,而torch.cat只是连接。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值