Pytorch索引、切片、连接


在这里插入图片描述


1.torch.cat()

  torch.cat() 是 PyTorch 库中的一个函数,用于沿指定维度连接张量。它接受一系列张量作为输入,并沿指定的维度进行连接。

torch.cat(tensors, dim=0, out=None)
"""
tensors:要连接的张量序列(例如,列表、元组)。
dim(可选):要沿其进行连接的维度。它指定了轴或维度编号。默认情况下,它设置为0,表示沿第一个维度进行连接。
out(可选):存储结果的输出张量。如果指定了 out,结果将存储在此张量中。如果未提供 out,则会创建一个新的张量来存储结果。
"""
import torch

# 创建两个张量
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

# 沿着维度0连接两个张量
result0 = torch.cat((tensor1, tensor2), dim=0)
result1 = torch.cat((tensor1, tensor2), dim=1)

print("result0",result0)
print("result1",result1)
result0 tensor([[1, 2],
        		[3, 4],
        		[5, 6],
        		[7, 8]])
result1 tensor([[1, 2, 5, 6],
        		[3, 4, 7, 8]])

2.torch.column_stack()

 torch.column_stack() 是 PyTorch 中的一个函数,用于按列堆叠张量来创建一个新的张量。它将输入张量沿着列的方向进行堆叠,并返回一个新的张量。

torch.column_stack(tensors)
"""
tensors:要堆叠的张量序列。它可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torch

tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

result = torch.column_stack((tensor1, tensor2))

print(result)
tensor([[1, 4],
        [2, 5],
        [3, 6]])

3.torch.gather()

torch.gather() 是 PyTorch 中的一个函数,用于根据给定的索引从输入张量中收集元素。它允许你按照指定的索引从输入张量中选择元素,并将它们组合成一个新的张量。

torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
input:输入张量,从中收集元素。
dim:指定索引的维度。
index:包含要收集元素的索引的张量。
out(可选):输出张量,用于存储结果。
sparse_grad(可选):指定是否启用稀疏梯度。默认为 False
"""

在这里插入图片描述

import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 索引张量
index = torch.tensor([[0, 0], [1, 0]])

# 根据索引从输入张量中收集元素
result = torch.gather(input, 1, index)

print(result)
#tensor([[1, 2],
#       [3, 2]])
import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 索引张量
index = torch.tensor([[0, 0], [1, 0]])

# 根据索引从输入张量中收集元素
result = torch.gather(input, 0, index)

print(result)

4.torch.hstack()

  torch.hstack() 是 PyTorch 中的一个函数,用于沿着水平方向(列维度)堆叠张量来创建一个新的张量。它将输入张量沿着水平方向进行堆叠,并返回一个新的张量。

torch.hstack(tensors) -> Tensor
"""
tensors:要堆叠的张量序列。可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

result = torch.hstack((tensor1, tensor2))

print(result)
# tensor([[1, 2, 5, 6],
#        [3, 4, 7, 8]])

5.torch.vstack()

torch.vstack()是PyTorch中用于沿垂直方向(行维度)堆叠张量的函数。它将输入张量沿垂直方向进行堆叠,并返回一个新的张量。

torch.vstack(tensors) -> Tensor
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

result = torch.vstack((tensor1, tensor2))

print(result)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

6.torch.index_select()

torch.index_select() 是 PyTorch 中的一个函数,用于按索引从输入张量中选择元素并返回一个新的张量。

torch.index_select(input, dim, index, out=None) -> Tensor
"""
input:输入张量,从中选择元素。
dim:指定索引的维度。即要在 input 张量的哪个维度上进行索引。
index:指定要选择的索引的张量。它的形状可以与 input 张量的形状不同,但必须满足广播规则。
out(可选):输出张量,用于存储结果。如果提供了 out,则结果将存储在此张量中。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引张量
index = torch.tensor([0, 2])

# 根据索引从输入张量中选择元素
result = torch.index_select(input, 0, index)

print(result)
tensor([[1, 2, 3],
        [7, 8, 9]])

7.torch.masked_select()

torch.masked_select() 是 PyTorch 中的一个函数,用于根据给定的掩码从输入张量中选择元素并返回一个新的张量。

torch.masked_select(input, mask, out=None) -> Tensor
"""
input:输入张量,从中选择元素。
mask:掩码张量,用于指定要选择的元素。mask 张量的形状必须与 input 张量的形状相同,或者满足广播规则。
out(可选):输出张量,用于存储结果。如果提供了 out,则结果将存储在此张量中。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 掩码张量
mask = torch.tensor([[True, False, True], [False, True, False], [True, False, True]])

# 根据掩码从输入张量中选择元素
result = torch.masked_select(input, mask)

print(result)
tensor([1, 3, 5, 7, 9])

8.torch.reshape

torch.reshape() 是 PyTorch 中的一个函数,用于改变张量的形状而不改变元素的数量。它返回一个具有新形状的新张量,其中的元素与原始张量相同。

torch.reshape(input, shape) -> Tensor
"""
input:输入张量,要改变形状的张量。
shape:指定的新形状。可以是一个整数元组或传递一个张量,其中包含新的形状。
torch.reshape() 函数将输入张量重新排列为指定的新形状。新的形状应该满足以下条件:

1. 新形状的元素数量与原始张量的元素数量相同。
2. 新形状中各维度的乘积与原始张量的元素数量相同。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 改变形状为 (3, 2)
result1 = torch.reshape(input, (3, 2))

# 改变形状为 (1, 6)
result2 = torch.reshape(input, (1, 6))

# 改变形状为 (6,)
result3 = torch.reshape(input, (6,))

print(result1)
print(result2)
print(result3)

9.torch.stack()

torch.stack() 是 PyTorch 中的一个函数,用于沿着新的维度对给定的张量序列进行堆叠操作。

torch.stack(tensors, dim=0, *, out=None) -> Tensor
"""
tensors:张量的序列,要进行堆叠操作的张量。
dim(可选):指定新的维度的位置。默认值为 0。
out(可选):输出张量。如果提供了输出张量,则将结果存储在该张量中。
"""
import torch

# 张量序列
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
tensor3 = torch.tensor([7, 8, 9])

# 在维度 0 上进行堆叠操作
result = torch.stack([tensor1, tensor2, tensor3], dim=0)

print(result)
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

torch.stack和torch.cat异同点
1.维度变化:

  • torch.stack 会在指定位置插入一个新的维度,从而增加张量的总维度数。
  • torch.cat 则不会增加新的维度,只是在指定的现有维度上进行连接。

2.输入要求:

  • torch.stack 要求所有输入张量的形状完全相同。
  • torch.cat 只要求输入张量在要连接的维度之外的其他维度形状相同。

10.torch.where()

torch.where() 是 PyTorch 中的一个函数,用于根据给定的条件从两个张量中选择元素。

torch.where(condition, x, y) -> Tensor
"""
condition:条件张量,一个布尔张量,用于指定元素选择的条件。
x:张量,与 condition 形状相同的张量,当对应位置的 condition 元素为 True 时,选择 x 中的对应元素。
y:张量,与 condition 形状相同的张量,当对应位置的 condition 元素为 False 时,选择 y 中的对应元素。
"""
import torch

# 条件张量
condition = torch.tensor([[True, False], [False, True]])

# 选择的张量 x
x = torch.tensor([[1, 2], [3, 4]])

# 选择的张量 y
y = torch.tensor([[5, 6], [7, 8]])

# 根据条件选择元素
result = torch.where(condition, x, y)

print(result)
#tensor([[1, 6],
#       [7, 4]])
import torch

# 输入张量
input = torch.tensor([1.5, 0.8, -1.2, 2.7, -3.5])

# 阈值
threshold = 0

# 根据阈值选择元素
result = torch.where(input > threshold, torch.tensor(1), torch.tensor(0))

print(result)#tensor([1, 1, 0, 1, 0])

11.torch.tile()

torch.tile() 是 PyTorch 中的一个函数,用于在指定维度上重复张量的元素。

torch.tile(input, reps) -> Tensor
"""
input:输入张量,要重复的张量。
reps:重复的次数,可以是一个整数或一个元组。
"""
import torch

# 输入张量
input = torch.tensor([1, 2, 3])

# 在维度 0 上重复 2 次
result = torch.tile(input, 2)

print(result)#tensor([1, 2, 3, 1, 2, 3])
import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 在维度 0 和维度 1 上重复
result = torch.tile(input, (2, 3))

print(result)
tensor([[1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4]])

12.torch.take()

torch.take() 是 PyTorch 中的一个函数,用于在给定索引处提取张量的元素。

torch.take(input, indices) -> Tensor
"""
input:输入张量,要从中提取元素的张量。
indices:索引张量,包含要提取的元素的索引。它可以是一个一维整数张量或一个具有相同形状的张量。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引张量
indices = torch.tensor([1, 4, 7])

# 提取元素
result = torch.take(input, indices)

print(result)# tensor([2, 5, 8])
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引张量
indices = torch.tensor([[0, 2], [1, 2]])

# 提取部分元素
result = torch.take(input, indices)

print(result)
tensor([[1, 3],
        [2, 3]])

13.torch.scatter()

torch.scatter() 是 PyTorch 中的一个函数,用于根据索引在张量中进行散射操作。散射操作是指根据给定的索引,将源张量的值散布(写入)到目标张量的指定位置。

在这里插入图片描述

torch.scatter(input, dim, index, src)
"""
input:输入张量,表示目标张量,散射操作将在此张量上进行。
dim:整数值,表示散射操作沿着的维度。
index:索引张量,指定散射操作的目标位置。
src:源张量,包含要散射到目标张量中的值。
"""
import torch

# 创建目标张量
target = torch.zeros(3, 4)

# 创建索引张量和源张量
index = torch.tensor([[0, 1, 2, 0], [2, 1, 0, 2]])
source = torch.tensor([1, 2, 3, 4])

# 执行散射操作
torch.scatter(target, dim=1, index=index, src=source)

print(target)
# 输出:
# tensor([[1., 4., 3., 1.],
#         [0., 3., 2., 0.],
#         [3., 2., 1., 3.]])
  • 10
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch中的tensor切片是指从一个tensor中选择特定的元素或子集。切片操作可以通过索引或范围来指定。下面是关于PyTorch tensor切片的一些重要信息: 1.基本切片操作:您可以使用索引操作符[]来对tensor进行切片。例如,如果有一个3x3的tensor,可以使用`tensor[1:3, 0:2]`来获得第二行和第三行的前两列。 2.索引规则:切片操作的索引是从0开始的。在切片时,起始索引是包含在切片中的,而结束索引是不包含在切片中的。例如,`tensor[1:3]`将返回索引为1和2的元素,但不包括索引为3的元素。 3.负数索引:您可以使用负数索引来从后面开始对tensor进行切片。例如,`tensor[-1]`将返回最后一个元素。 4.步长操作:您可以使用步长操作来跳过某些元素进行切片。例如,`tensor[0:3:2]`将返回索引为0和2的元素。 5.高维tensor切片:对于高维tensor,您可以在多个维度上进行切片。例如,`tensor[:, 1]`将返回所有行的第二列。 6.更改切片切片的结果是原始tensor的视图,并且共享相同的内存。因此,对切片的更改将反映在原始tensor上。 7.使用切片进行赋值:您可以使用切片操作来对tensor的某些元素进行赋值。例如,`tensor[1:3, 0:2] = 0`将第二行和第三行的前两列设置为0。 请注意,这只是关于PyTorch tensor切片的一些基本信息,更复杂的操作如高级索引和掩码索引等也是可行的。 ### 回答2: PyTorch中的tensor切片是指从一个tensor中选择部分元素的操作。通过切片操作,我们可以访问或修改tensor中的特定元素,或者创建一个新的tensor来存储所选元素。 切片操作的基本语法是t[start:stop:step],其中start表示起始位置,stop表示结束位置(但不包括该位置上的元素),step表示步长。 例如,如果有一个1维tensor t = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],我们可以使用切片操作来选择其中的一部分元素。 - t[2:6]将返回一个新的tensor,包含元素2, 3, 4, 5; - t[:5]将返回一个新的tensor,包含元素0, 1, 2, 3, 4; - t[5:]将返回一个新的tensor,包含元素5, 6, 7, 8, 9; - t[1:8:2]将返回一个新的tensor,包含元素1, 3, 5, 7。 对于多维tensor,我们可以使用相同的切片操作来选择各个维度上的元素。 例如,如果有一个2维tensor t = [[0, 1, 2], [3, 4, 5], [6, 7, 8]],我们可以使用切片操作来选择其中的一部分元素。 - t[1:3, :2]将返回一个新的tensor,包含元素[[3, 4], [6, 7]],表示选择第1行和第2行的前2列; - t[:, 1]将返回一个新的tensor,包含元素[1, 4, 7],表示选择所有行的第1列。 需要注意的是,切片操作返回的是原始tensor的一个视图,而不是创建一个新的tensor。这意味着对切片后的tensor进行修改,将会影响到原始tensor。如果需要创建一个新的tensor对象,可以使用切片操作的clone()方法来复制原始tensor的数据。 ### 回答3: PyTorch是一个常用的深度学习框架,Tensor是PyTorch中用于处理数据的基本数据结构。在PyTorch中,我们可以使用Tensor进行切片操作来选择或修改我们需要的元素。 通过索引操作,我们可以对Tensor进行切片。在切片操作中,可以使用逗号分隔的索引列表来选择多个维度的元素。例如,使用tensor[a:b, c:d]的切片操作,可以选择Tensor中从第a行到第b行(不包括b)以及第c列到第d列(不包括d)的元素。 在切片操作中,索引的开始和结束位置都是可选的,如果不指定,则默认为从开头到末尾。此外,还可以使用负数索引来表示从末尾开始的位置。 除了使用切片进行选择之外,我们还可以使用切片进行修改。通过将切片操作放在赋值语句的左侧,我们可以将新的值赋予切片所选择的元素。 值得注意的是,切片操作返回的是原始Tensor的视图,而不是复制。这意味着对切片的修改也会反映在原始Tensor上。 需要注意的是,在PyTorch中进行切片操作不会对Tensor进行内存复制,这样可以减少内存消耗并提高代码的执行效率。 总而言之,PyTorch中的Tensor切片操作允许我们根据需要选择或修改Tensor中的元素。通过索引切片操作,我们可以根据具体需求灵活操作Tensor的数据。这为我们在深度学习任务中提供了丰富的选择和便利性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值