一、torch.cat 是什么?
torch.cat 是 PyTorch 中的一个函数
,用于沿着某个维度连接张量。
torch.cat 接受一个张量列表
,并沿着某个维度连接它们。这个函数会返回一个新的张量,其中包含了所有输入张量的元素。
特别是在处理序列数据(如文本或音频)时,需要将多个小的张量连接起来形成一个大的张量。例如,在自然语言处理中,需要将多个词向量连接起来形成一个句子的向量表示
。
二、使用步骤
import torch
# 创建两个大小为 (2, 3) 的张量
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 沿着第一个维度连接这两个张量
result = torch.cat((tensor1, tensor2), dim=0)
print(result)
输出
:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]])
总结
在这个例子中,我们沿着第一个维度(也就是行
)连接了两个张量。结果是一个大小为 (4, 3) 的张量,其中包含了原来两个张量的所有元素。