torch.cat(dim=) 是 PyTorch 中用于在指定维度上拼接(concatenate)张量的函数。
参数 dim=1 表示沿着第一个维度(即列维度)进行拼接。参数 dim=0 表示沿着行维度进行拼接。
假设你有两个张量 tensor1 和 tensor2:
import torch
tensor1 = torch.tensor([[1, 2],
[3, 4]])
tensor2 = torch.tensor([[5, 6],
[7, 8]])
如果在列的维度上拼接这两个张量,使用 torch.cat(dim=1):
result = torch.cat([tensor1, tensor2], dim=1)
result为:
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
若使用torch.cat(dim=0),它将在行的维度上进行拼接,即增加行数,列数不变。
result = torch.cat([tensor1, tensor2], dim=0)
result为:
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])