(1)本文涉及函数的列表(注释在代码中)
torch.cat 连接张量,和stack相似
torch.chunk 分块
torch.gather 聚合 图解PyTorch中的torch.gather函数 - 知乎
torch.index_select 切片
torch.masked_select 根据mask二元值,返回1维张量
torch.nonzero 返回非零元素的索引
torch.split 切割成相同的快,最后一个块可以不同
torch.squeeze 挤压维度
torch.stack 连接张量序列,和cat相似
orch.t 转置0,1维
torch.transpose 交换维度
torch.unbind 移除指定维度,返回元组(各个切片)
torch.unsqueeze 扩展维度
(2)代码示例(含注释)
"""
索引,切片,连接,换位 Indexing, Slicing, Joining, Mutating Ops
"""
import torch
# # 在给定维度上对输入的张量序列 seq 进行连接操作。
x = torch.randn(2, 3)
y = torch.zeros(2, 3)
z = torch.randn(2, 3)
obj1 = torch.cat((x, y, z), dim=0) # 沿y轴方向连接
obj2 = torch.cat((x, y, z), dim=1) # 沿x轴方向连接
# # 在给定维度(轴)上将输入张量进行分块儿。
obj3 = torch.chunk(torch.randn(3, 6), chunks=3, dim=0)
obj4 = torch.chunk(torch.randn(3, 6), chunks=3, dim=1)
# # 沿给定轴 dim,将输入索引张量 index 指定位置的值进行聚合。index (LongTensor) – 聚合元素的下标
# https://zhuanlan.zhihu.com/p/352877584
x = torch.Tensor([[3, 4, 5], [6, 7, 8], [9, 10, 11]])
# dims=0: [0, 0] [2, 1] [1, 2] 左侧是021 右侧是012
# dims=1: [0, 0] [0, 2] [0 ,1] 右侧是021 左侧是000
index = torch.LongTensor([[0, 2, 1]])
obj5 = torch.gather(x, 0, index)
obj6 = torch.gather(x, 1, index)
# # 沿着指定维度对输入进行切片,取 index 中指定的相应项(index 为一个 LongTensor),
# # 然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。
# # 注意: 返回的张量不与原始张量共享内存空间。
x = torch.randn(3, 4)
# 沿y轴方向取第0行和第2行
obj7 = torch.index_select(input=x, dim=0, index=torch.LongTensor([0, 2]))
# 沿x轴方向取第0行和第2行
obj8 = torch.index_select(input=x, dim=1, index=torch.LongTensor([0, 2]))
# # 根据掩码张量 mask 中的二元值,取输入张量中的指定项( mask 为一个 ByteTensor),将取值返回到一个新的 1维 张量,
# # 张量 mask 须跟 input 张量有相同数量的元素数目,但形状或维度不需要相同。
# # 注意:返回的 1维 张量不与原始张量共享内存空间。
# mask (ByteTensor) – 掩码张量,包含了二元索引值
x = torch.randn(3, 4)
mask = x.ge(0.5) # 大于0.5返回True
obj9 = torch.masked_select(input=x, mask=mask)
# # 返回一个包含输入 input 中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
# # 如果输入 input 有 n 维,则输出的索引张量 output 的形状为 z x n,
# # 这里 z 是输入张量 input 中所有非零元素的个数。
obj10 = torch.nonzero(torch.Tensor([1, 1, 0, 0, 1]))
obj11 = torch.nonzero(torch.eye(3, 3))
# # 将输入张量分割成相等形状的 chunks(如果可分)。
# # 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。
obj12 = torch.split(tensor=torch.randn(3, 4), split_size_or_sections=2, dim=1)
# # 挤压:将输入张量形状中的 1 去除并返回。(默认沿x、y两个方向)
x = torch.zeros([2, 1, 3, 4, 1]) # 五维 torch.Size([2, 1, 3, 4, 1])
obj13 = torch.squeeze(x) # 返回三维 torch.Size([2, 3, 4])
# torch.squeeze(x, 0) # torch.Size([2, 1, 3, 4, 1])
# torch.squeeze(x, 1) # torch.Size([2, 3, 4, 1])
# # 沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
# # 对比 torch.cat() 输入二维,返回二维
# # torch.stack() 输入二维,返回三维
x = torch.randn(2, 3)
y = torch.zeros(2, 3)
z = torch.randn(2, 3)
obj14 = torch.stack((x, y, z), dim=0) # 沿y轴方向拼接
obj15 = torch.stack((x, y, z), dim=1) # 沿x轴方向拼接
# # 输入一个矩阵(2 维张量),并转置 0, 1 维。
# # 可以被视为函数 transpose(input, 0, 1) 的简写函数。
obj16 = torch.t(torch.tensor([[1, 2], [4, 6]]))
# # 返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,
# # 所以改变其中一个会导致另外一个也被修改。
obj17 = torch.transpose(torch.tensor([[1, 2], [4, 6]]), 0, 1)
# # 移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片
obj18 = torch.unbind(torch.tensor([[1, 2], [3, 4]]), dim=0)
# # 返回一个新的张量,对输入的制定位置插入维度 1
# # 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
# # 如果 dim 为负,则将会被转化 dim+input.dim()+1
obj19 = torch.unsqueeze(torch.tensor([1, 2, 3]), dim=0) # 沿x轴方向扩张维度 1
obj20 = torch.unsqueeze(torch.tensor([1, 2, 3]), dim=1) # 沿y轴方向扩张维度 1
print("*"*20, "obj1", "*"*20, "\n", obj1, "\n")
print("*"*20, "obj2", "*"*20, "\n", obj2, "\n")
print("*"*20, "obj3", "*"*20, "\n", obj3, "\n")
print("*"*20, "obj4", "*"*20, "\n", obj4, "\n")
print("*"*20, "obj5", "*"*20, "\n", obj5, "\n")
print("*"*20, "obj6", "*"*20, "\n", obj6, "\n")
print("*"*20, "obj7", "*"*20, "\n", x, "\n", obj7, "\n", obj8, "\n")
print("*"*20, "obj9", "*"*20, "\n", x, "\n", mask, "\n", obj9, "\n")
print("*"*20, "obj10", "*"*20, "\n", obj10, "\n", obj11, "\n")
print("*"*20, "obj12", "*"*20, "\n", obj12, "\n")
print("*"*20, "obj13", "*"*20, "\n", x, "\n", x.size(), "\n", obj13, "\n", obj13.size(), "\n")
print("*"*20, "obj14", "*"*20, "\n", obj14, "\n", obj15, "\n")
print("*"*20, "obj16", "*"*20, "\n", obj16, "\n")
print("*"*20, "obj17", "*"*20, "\n", obj17, "\n")
print("*"*20, "obj18", "*"*20, "\n", obj18, "\n")
print("*"*20, "obj19", "*"*20, "\n", obj19, "\n", obj20, "\n")
>>>output
******************** obj1 ********************
tensor([[-1.6186, -0.8067, -1.1804],
[-0.1467, 0.6239, 0.9534],
[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000],
[ 0.9263, -0.8269, 0.0046],
[-1.1006, -0.6988, 0.1890]])******************** obj2 ********************
tensor([[-1.6186, -0.8067, -1.1804, 0.0000, 0.0000, 0.0000, 0.9263, -0.8269,
0.0046],
[-0.1467, 0.6239, 0.9534, 0.0000, 0.0000, 0.0000, -1.1006, -0.6988,
0.1890]])******************** obj3 ********************
(tensor([[ 1.4899, -0.3238, -0.7243, -0.5336, -0.6189, -0.0321]]),tensor([[ 0.8191, -0.4457, 1.5956, -0.2469, -0.1813, -0.1472]]),
tensor([[-0.4933, 1.4287, -0.1502, -0.8178, -2.3286, 1.0828]]))
******************** obj4 ********************
(tensor([[-0.8373, -1.3918],
[-0.5960, -0.6540],
[-1.4265, -0.4198]]), tensor([[-0.9599, -1.1114],
[ 1.2388, -1.6912],
[ 0.4254, 0.3522]]), tensor([[-0.2481, 0.9961],
[ 1.1239, -0.0241],
[-0.4623, -1.2694]]))******************** obj5 ********************
tensor([[ 3., 10., 8.]])******************** obj6 ********************
tensor([[3., 5., 4.]])******************** obj7 ********************
tensor([[ 0.6158, -0.1513, 0.7025],
[-0.6720, 0.8464, -1.6161]])
tensor([[ 0.3243, -0.9802, 0.8814, 1.0558],
[-0.2766, 2.1381, 0.3016, 0.0050]])
tensor([[ 0.3243, 0.8814],
[-0.9072, 0.7438],
[-0.2766, 0.3016]])******************** obj9 ********************
tensor([[ 0.6158, -0.1513, 0.7025],
[-0.6720, 0.8464, -1.6161]])
tensor([[False, False, False, True],
[False, False, True, False],
[False, False, True, True]])
tensor([1.3356, 1.6224, 0.8899, 0.8873])******************** obj10 ********************
tensor([[0],
[1],
[4]])
tensor([[0, 0],
[1, 1],
[2, 2]])******************** obj12 ********************
(tensor([[-1.6856, -1.5886],
[-0.7732, 2.1042],
[ 0.4190, -0.0502]]), tensor([[-0.6008, 0.5552],
[-0.4270, 0.3902],
[-2.3373, -0.8840]]))******************** obj13 ********************
tensor([[ 0.6158, -0.1513, 0.7025],
[-0.6720, 0.8464, -1.6161]])
torch.Size([2, 3])
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
torch.Size([2, 3, 4])******************** obj14 ********************
tensor([[[ 0.6158, -0.1513, 0.7025],
[-0.6720, 0.8464, -1.6161]],[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]],[[ 0.9513, -0.0260, -0.0575],
[ 1.5535, 0.5139, 0.0120]]])
tensor([[[ 0.6158, -0.1513, 0.7025],
[ 0.0000, 0.0000, 0.0000],
[ 0.9513, -0.0260, -0.0575]],[[-0.6720, 0.8464, -1.6161],
[ 0.0000, 0.0000, 0.0000],
[ 1.5535, 0.5139, 0.0120]]])******************** obj16 ********************
tensor([[1, 4],
[2, 6]])******************** obj17 ********************
tensor([[1, 4],
[2, 6]])******************** obj18 ********************
(tensor([1, 2]), tensor([3, 4]))******************** obj19 ********************
tensor([[1, 2, 3]])
tensor([[1],
[2],
[3]])
>>>如有疑问,欢迎评论区一起探讨