拼接张量
torch.cat((cls_token, x), dim=1)
dim = 1:
import torch
a = torch.tensor([[1, 2],
[3, 4]])
b = torch.tensor([[5, 6],
[7, 8]])
result = torch.cat((a, b), dim=1)
# 在第二个维度上进行拼接
#在[[1, 2], [3, 4]]的第二个"["进行拼接
print(result)
#tensor([[1, 2, 5, 6],
# [3, 4, 7, 8]])