文章目录

1.torch.cat()
torch.cat() 是 PyTorch 库中的一个函数,用于沿指定维度连接张量。它接受一系列张量作为输入,并沿指定的维度进行连接。
torch.cat(tensors, dim=0, out=None)
"""
tensors:要连接的张量序列(例如,列表、元组)。
dim(可选):要沿其进行连接的维度。它指定了轴或维度编号。默认情况下,它设置为0,表示沿第一个维度进行连接。
out(可选):存储结果的输出张量。如果指定了 out,结果将存储在此张量中。如果未提供 out,则会创建一个新的张量来存储结果。
"""
import torch
# 创建两个张量
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
# 沿着维度0连接两个张量
result0 = torch.cat((tensor1, tensor2), dim=0)
result1 = torch.cat((tensor1, tensor2), dim=1)
print("result0",result0)
print("result1",result1)
result0 tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
result1 tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
2.torch.column_stack()
torch.column_stack() 是 PyTorch 中的一个函数,用于按列堆叠张量来创建一个新的张量。它将输入张量沿着列的方向进行堆叠,并返回一个新的张量。
torch.column_stack(tensors)
"""
tensors:要堆叠的张量序列。它可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
result = torch.column_stack((tensor1, tensor2))
print(result)
tensor([[1, 4],
[2, 5],
[3, 6]])
3.torch.gather()
torch.gather() 是 PyTorch 中的一个函数,用于根据给定的索引从输入张量中收集元素。它允许你按照指定的索引从输入张量中选择元素,并将它们组合成一个新的张量。
torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
input:输入张量,从中收集元素。
dim:指定索引的维度。
index:包含要收集元素的索引的张量。
out(可选):输出张量,用于存储结果。
sparse_grad(可选):指定是否启用稀疏梯度。默认为 False
"""

import torch
# 输入张量
input = torch.tensor([[1, 2], [3, 4]])
# 索引张量
index = torch.tensor([[0, 0], [1, 0]])
# 根据索引从输入张量中收集元素
result = torch.gather(input, 1, index

最低0.47元/天 解锁文章
1153

被折叠的 条评论
为什么被折叠?



