Pytorch索引、切片、连接


在这里插入图片描述


1.torch.cat()

  torch.cat() 是 PyTorch 库中的一个函数,用于沿指定维度连接张量。它接受一系列张量作为输入,并沿指定的维度进行连接。

torch.cat(tensors, dim=0, out=None)
"""
tensors:要连接的张量序列(例如,列表、元组)。
dim(可选):要沿其进行连接的维度。它指定了轴或维度编号。默认情况下,它设置为0,表示沿第一个维度进行连接。
out(可选):存储结果的输出张量。如果指定了 out,结果将存储在此张量中。如果未提供 out,则会创建一个新的张量来存储结果。
"""
import torch

# 创建两个张量
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

# 沿着维度0连接两个张量
result0 = torch.cat((tensor1, tensor2), dim=0)
result1 = torch.cat((tensor1, tensor2), dim=1)

print("result0",result0)
print("result1",result1)
result0 tensor([[1, 2],
        		[3, 4],
        		[5, 6],
        		[7, 8]])
result1 tensor([[1, 2, 5, 6],
        		[3, 4, 7, 8]])

2.torch.column_stack()

 torch.column_stack() 是 PyTorch 中的一个函数,用于按列堆叠张量来创建一个新的张量。它将输入张量沿着列的方向进行堆叠,并返回一个新的张量。

torch.column_stack(tensors)
"""
tensors:要堆叠的张量序列。它可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torch

tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

result = torch.column_stack((tensor1, tensor2))

print(result)
tensor([[1, 4],
        [2, 5],
        [3, 6]])

3.torch.gather()

torch.gather() 是 PyTorch 中的一个函数,用于根据给定的索引从输入张量中收集元素。它允许你按照指定的索引从输入张量中选择元素,并将它们组合成一个新的张量。

torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
input:输入张量,从中收集元素。
dim:指定索引的维度。
index:包含要收集元素的索引的张量。
out(可选):输出张量,用于存储结果。
sparse_grad(可选):指定是否启用稀疏梯度。默认为 False
"""

在这里插入图片描述

import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 索引张量
index = torch.tensor([[0, 0], [1, 0]])

# 根据索引从输入张量中收集元素
result = torch.gather(input, 1, index
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值