即看即用 && 索引,切片,连接,换位 (Indexing, Slicing, Joining, Mutating Ops) && Pytorch官方文档总结 && 笔记 (二)

(1)本文涉及函数的列表(注释在代码中)

  1. torch.cat                                    连接张量,和stack相似

  2. torch.chunk                               分块

  3. torch.gather                               聚合 图解PyTorch中的torch.gather函数 - 知乎

  4. torch.index_select                    切片

  5. torch.masked_select                根据mask二元值,返回1维张量

  6. torch.nonzero                            返回非零元素的索引

  7. torch.split                                  切割成相同的快,最后一个块可以不同

  8. torch.squeeze                           挤压维度

  9. torch.stack                                连接张量序列,和cat相似

  10. orch.t                                         转置0,1维

  11. torch.transpose                        交换维度

  12. torch.unbind                             移除指定维度,返回元组(各个切片)

  13. torch.unsqueeze                       扩展维度

(2)代码示例(含注释)

"""
索引,切片,连接,换位 Indexing, Slicing, Joining, Mutating Ops
"""
import torch

# # 在给定维度上对输入的张量序列 seq 进行连接操作。
x = torch.randn(2, 3)
y = torch.zeros(2, 3)
z = torch.randn(2, 3)
obj1 = torch.cat((x, y, z), dim=0)  # 沿y轴方向连接
obj2 = torch.cat((x, y, z), dim=1)  # 沿x轴方向连接

# # 在给定维度(轴)上将输入张量进行分块儿。
obj3 = torch.chunk(torch.randn(3, 6), chunks=3, dim=0)
obj4 = torch.chunk(torch.randn(3, 6), chunks=3, dim=1)

# # 沿给定轴 dim,将输入索引张量 index 指定位置的值进行聚合。index (LongTensor) – 聚合元素的下标
# https://zhuanlan.zhihu.com/p/352877584
x = torch.Tensor([[3, 4, 5], [6, 7, 8], [9, 10, 11]])
# dims=0: [0, 0] [2, 1] [1, 2]  左侧是021 右侧是012
# dims=1: [0, 0] [0, 2] [0 ,1]  右侧是021 左侧是000
index = torch.LongTensor([[0, 2, 1]])
obj5 = torch.gather(x, 0, index)
obj6 = torch.gather(x, 1, index)

# # 沿着指定维度对输入进行切片,取 index 中指定的相应项(index 为一个 LongTensor),
# # 然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。
# # 注意: 返回的张量不与原始张量共享内存空间。
x = torch.randn(3, 4)
# 沿y轴方向取第0行和第2行
obj7 = torch.index_select(input=x, dim=0, index=torch.LongTensor([0, 2]))
# 沿x轴方向取第0行和第2行
obj8 = torch.index_select(input=x, dim=1, index=torch.LongTensor([0, 2]))

# # 根据掩码张量 mask 中的二元值,取输入张量中的指定项( mask 为一个 ByteTensor),将取值返回到一个新的 1维 张量,
# # 张量 mask 须跟 input 张量有相同数量的元素数目,但形状或维度不需要相同。
# # 注意:返回的 1维 张量不与原始张量共享内存空间。
# mask (ByteTensor) – 掩码张量,包含了二元索引值
x = torch.randn(3, 4)
mask = x.ge(0.5)  # 大于0.5返回True
obj9 = torch.masked_select(input=x, mask=mask)

# # 返回一个包含输入 input 中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
# # 如果输入 input 有 n 维,则输出的索引张量 output 的形状为 z x n,
# # 这里 z 是输入张量 input 中所有非零元素的个数。
obj10 = torch.nonzero(torch.Tensor([1, 1, 0, 0, 1]))
obj11 = torch.nonzero(torch.eye(3, 3))

# # 将输入张量分割成相等形状的 chunks(如果可分)。
# # 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。
obj12 = torch.split(tensor=torch.randn(3, 4), split_size_or_sections=2, dim=1)

# # 挤压:将输入张量形状中的 1 去除并返回。(默认沿x、y两个方向)
x = torch.zeros([2, 1, 3, 4, 1])  # 五维 torch.Size([2, 1, 3, 4, 1])
obj13 = torch.squeeze(x)  # 返回三维 torch.Size([2, 3, 4])
# torch.squeeze(x, 0) # torch.Size([2, 1, 3, 4, 1])
# torch.squeeze(x, 1) # torch.Size([2, 3, 4, 1])

# # 沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
# # 对比 torch.cat()  输入二维,返回二维
# # torch.stack()  输入二维,返回三维
x = torch.randn(2, 3)
y = torch.zeros(2, 3)
z = torch.randn(2, 3)
obj14 = torch.stack((x, y, z), dim=0)  # 沿y轴方向拼接
obj15 = torch.stack((x, y, z), dim=1)  # 沿x轴方向拼接

# # 输入一个矩阵(2 维张量),并转置 0, 1 维。
# # 可以被视为函数 transpose(input, 0, 1) 的简写函数。
obj16 = torch.t(torch.tensor([[1, 2], [4, 6]]))

# # 返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,
# # 所以改变其中一个会导致另外一个也被修改。
obj17 = torch.transpose(torch.tensor([[1, 2], [4, 6]]), 0, 1)

# # 移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片
obj18 = torch.unbind(torch.tensor([[1, 2], [3, 4]]), dim=0)

# # 返回一个新的张量,对输入的制定位置插入维度 1
# # 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
# # 如果 dim 为负,则将会被转化 dim+input.dim()+1
obj19 = torch.unsqueeze(torch.tensor([1, 2, 3]), dim=0)  # 沿x轴方向扩张维度 1
obj20 = torch.unsqueeze(torch.tensor([1, 2, 3]), dim=1)  # 沿y轴方向扩张维度 1

print("*"*20, "obj1", "*"*20, "\n", obj1, "\n")
print("*"*20, "obj2", "*"*20, "\n", obj2, "\n")
print("*"*20, "obj3", "*"*20, "\n", obj3, "\n")
print("*"*20, "obj4", "*"*20, "\n", obj4, "\n")
print("*"*20, "obj5", "*"*20, "\n", obj5, "\n")
print("*"*20, "obj6", "*"*20, "\n", obj6, "\n")
print("*"*20, "obj7", "*"*20, "\n", x, "\n", obj7, "\n", obj8, "\n")
print("*"*20, "obj9", "*"*20, "\n", x, "\n", mask, "\n", obj9, "\n")
print("*"*20, "obj10", "*"*20, "\n", obj10, "\n", obj11, "\n")
print("*"*20, "obj12", "*"*20, "\n", obj12, "\n")
print("*"*20, "obj13", "*"*20, "\n", x, "\n", x.size(), "\n", obj13, "\n", obj13.size(), "\n")
print("*"*20, "obj14", "*"*20, "\n", obj14, "\n", obj15, "\n")
print("*"*20, "obj16", "*"*20, "\n", obj16, "\n")
print("*"*20, "obj17", "*"*20, "\n", obj17, "\n")
print("*"*20, "obj18", "*"*20, "\n", obj18, "\n")
print("*"*20, "obj19", "*"*20, "\n", obj19, "\n", obj20, "\n")

>>>output

******************** obj1 ******************** 
 tensor([[-1.6186, -0.8067, -1.1804],
              [-0.1467,  0.6239,  0.9534],
              [ 0.0000,  0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000],
              [ 0.9263, -0.8269,  0.0046],
              [-1.1006, -0.6988,  0.1890]]) 

******************** obj2 ******************** 
 tensor([[-1.6186, -0.8067, -1.1804,  0.0000,  0.0000,  0.0000,  0.9263, -0.8269,
                0.0046],
              [-0.1467,  0.6239,  0.9534,  0.0000,  0.0000,  0.0000, -1.1006, -0.6988,
                0.1890]]) 

******************** obj3 ******************** 
 (tensor([[ 1.4899, -0.3238, -0.7243, -0.5336, -0.6189, -0.0321]]),

  tensor([[ 0.8191,  -0.4457,  1.5956, -0.2469, -0.1813, -0.1472]]),

  tensor([[-0.4933,  1.4287, -0.1502, -0.8178, -2.3286,  1.0828]])) 

******************** obj4 ******************** 
 (tensor([[-0.8373, -1.3918],
              [-0.5960, -0.6540],
              [-1.4265, -0.4198]]), tensor([[-0.9599, -1.1114],
              [ 1.2388, -1.6912],
              [ 0.4254,  0.3522]]), tensor([[-0.2481,  0.9961],
              [ 1.1239, -0.0241],
              [-0.4623, -1.2694]])) 

******************** obj5 ******************** 
 tensor([[ 3., 10.,  8.]]) 

******************** obj6 ******************** 
 tensor([[3., 5., 4.]]) 

******************** obj7 ******************** 
 tensor([[ 0.6158, -0.1513,  0.7025],
              [-0.6720,  0.8464, -1.6161]]) 
 tensor([[ 0.3243, -0.9802,  0.8814,  1.0558],
              [-0.2766,  2.1381,  0.3016,  0.0050]]) 
 tensor([[ 0.3243,  0.8814],
              [-0.9072,  0.7438],
              [-0.2766,  0.3016]]) 

******************** obj9 ******************** 
 tensor([[ 0.6158, -0.1513,  0.7025],
              [-0.6720,  0.8464, -1.6161]]) 
 tensor([[False, False, False,  True],
              [False, False,  True, False],
              [False, False,  True,  True]]) 
 tensor([1.3356, 1.6224, 0.8899, 0.8873]) 

******************** obj10 ******************** 
 tensor([[0],
              [1],
              [4]]) 
 tensor([[0, 0],
              [1, 1],
             [2, 2]]) 

******************** obj12 ******************** 
 (tensor([[-1.6856, -1.5886],
             [-0.7732,  2.1042],
             [ 0.4190, -0.0502]]), tensor([[-0.6008,  0.5552],
             [-0.4270,  0.3902],
             [-2.3373, -0.8840]])) 

******************** obj13 ******************** 
 tensor([[ 0.6158, -0.1513,  0.7025],
              [-0.6720,  0.8464, -1.6161]]) 
 torch.Size([2, 3]) 
 tensor([[[0., 0., 0., 0.],
               [0., 0., 0., 0.],
               [0., 0., 0., 0.]],

              [[0., 0., 0., 0.],
               [0., 0., 0., 0.],
               [0., 0., 0., 0.]]]) 
 torch.Size([2, 3, 4]) 

******************** obj14 ******************** 
 tensor([[[ 0.6158, -0.1513,  0.7025],
                [-0.6720,  0.8464, -1.6161]],

               [[ 0.0000,  0.0000,  0.0000],
                [ 0.0000,  0.0000,  0.0000]],

               [[ 0.9513, -0.0260, -0.0575],
                [ 1.5535,  0.5139,  0.0120]]]) 
 tensor([[[ 0.6158, -0.1513,  0.7025],
                [ 0.0000,  0.0000,  0.0000],
                [ 0.9513, -0.0260, -0.0575]],

               [[-0.6720,  0.8464, -1.6161],
                [ 0.0000,  0.0000,  0.0000],
                [ 1.5535,  0.5139,  0.0120]]]) 

******************** obj16 ******************** 
 tensor([[1, 4],
              [2, 6]]) 

******************** obj17 ******************** 
 tensor([[1, 4],
              [2, 6]]) 

******************** obj18 ******************** 
 (tensor([1, 2]), tensor([3, 4])) 

******************** obj19 ******************** 
 tensor([[1, 2, 3]]) 
 tensor([[1],
              [2],
              [3]]) 

 >>>如有疑问,欢迎评论区一起探讨

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Flying Bulldog

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值