PyTorch框架学习三——张量操作

一、拼接

1.torch.cat()

功能:将tensor按照维度dim进行拼接,除了需要拼接的维度外,其余维度尺寸得是相同的。

torch.cat(tensors, dim=0, out=None)

看一下所有的参数:
在这里插入图片描述

  1. tensors:需要被拼接的张量序列。
  2. dim:(int,可选)被拼接的维度,默认为0。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

2.torch.stack()

功能:在新创建的维度dim上进行拼接,所有的张量必须是相同的维度。

torch.stack(tensors, dim=0, out=None)

在这里插入图片描述
注意:stack()会创建一个新的维度。

t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

在这里插入图片描述
原来t的维度是(2, 3),本来是没有第三维的,但是stack()会构建新的dim=2,就是先构建第三维dim=2,然后在该维度上进行拼接。

二、切分

1.torch.chunk()

功能:将tensor按维度dim进行平均切分。如果不能整除,最后一份tensor在该维度上的长度小于其他tensor。

torch.chunk(input, chunks, dim=0)

在这里插入图片描述

  1. input:要切分的张量。
  2. chunks:要切分的份数。
  3. dim:要切分的维度,默认为0。
a = torch.ones((2, 7))  # 7
list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3

for idx, t in enumerate(list_of_tensors):
    print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

在这里插入图片描述

2.torch.split()

功能:将tensor按dim进行切分。

torch.split(tensor, split_size_or_sections, dim=0)

在这里插入图片描述

  1. tensor:要切分的张量。
  2. split_size_or_sections:(int或list(int))为int时,表示每一份的长度,如果不能整除,最后一份的长度要小于其他的张量,为list时,按list元素来切分。
  3. dim:同上。
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

三、索引

1.torch.index_select()

功能:在dim上,按照index索引数据,返回一个依据index索引数据拼接的张量。

torch.index_select(input, dim, index, out=None)

在这里插入图片描述

  1. input:要索引的张量。
  2. dim:被索引的维度。
  3. index:一维张量,包括了要索引的数据序号。(long,不能是float)
  4. out:输出张量(可选)。
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

2.torch.masked_select()

功能:按照mask中的True进行索引,返回一个一维张量。

torch.masked_select(input, mask, out=None)

在这里插入图片描述

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
        [False, True, True, True],
        [False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])

四、变换

1.torch.reshape()

功能:变换张量的形状。

torch.reshape(input, shape)

在这里插入图片描述

  1. input:输入张量。
  2. shape:新张量的形状。当某个维度为-1时,表示该维度不用关心,可以从别的维度计算得到。
>>> a = torch.arange(4.)
>>> torch.reshape(a, (2, 2))
tensor([[ 0.,  1.],
        [ 2.,  3.]])
>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
tensor([ 0,  1,  2,  3])

2.torch.transpace()

功能:交换tensor的两个维度。

torch.transpose(input, dim0, dim1)

在这里插入图片描述

  1. input:输入张量。
  2. dim0和dim1:要交换的两个维度。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893,  0.5809],
        [-0.1669,  0.7299,  0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
        [-0.9893,  0.7299],
        [ 0.5809,  0.4942]])

3.torch.t()

功能:2维tensor转置,对矩阵而言。等价于torch.transpose(input, 0, 1)。

torch.t(input)
>>> x = torch.randn(())
>>> x
tensor(0.1995)
>>> torch.t(x)
tensor(0.1995)
>>> x = torch.randn(3)
>>> x
tensor([ 2.4320, -0.4608,  0.7702])
>>> torch.t(x)
tensor([ 2.4320, -0.4608,  0.7702])
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.4875,  0.9158, -0.5872],
        [ 0.3938, -0.6929,  0.6932]])
>>> torch.t(x)
tensor([[ 0.4875,  0.3938],
        [ 0.9158, -0.6929],
        [-0.5872,  0.6932]])

注意:只对矩阵会转置,对标量和向量都不会。

4.torch.squeeze()

功能:压缩长度为1的维度(轴)。

torch.squeeze(input, dim=None, out=None)

在这里插入图片描述

  1. input:输入张量。
  2. dim:(可选)若为None,移除所有长度为1的轴,若指定轴,当且仅当该轴长度为1时移除。
  3. out:输出张量。
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

5.torch.unsqueeze()

功能:返回一个新的张量,对输入的指定位置插入维度 1。

torch.unsqueeze(input, dim)
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值