torch.cat()用法

torch.cat()用法

torch.cat(tensors, dim=0, out=None)

用于将给定的一个tensor序列在指定维度进行拼接(常用在将一个tensor列表进行拼接).(通过下面的示例可以更好的理解)

cat操作不会增加维度

限制条件:除非有些序列里面的有些tensor为空,其余的所有tensor在拼接维度以外的其他维度的size都要一致

参数

  • tensors(sequence of Tensors) 输入的是一个tensor的序列,要遵循上述的限制条件
  • dim(int, optional) 指定在哪个维度进行拼接
  • out(Tensor, optional) 输出的拼接好的tensor(可有可无)

示例

x = torch.randn(2, 3)
y = [x, x, x]  # 构造一个tensor序列
y
>>>[tensor([[-0.2029, -1.2860,  1.0803],
            [ 0.4547,  0.4816,  0.6233]]),
 tensor([[-0.2029, -1.2860,  1.0803],
         [ 0.4547,  0.4816,  0.6233]]),
 tensor([[-0.2029, -1.2860,  1.0803],
         [ 0.4547,  0.4816,  0.6233]])]
# 在维度0进行拼接,他们的维度1的shape是一致的,(这边举的例子维度0正好是一致的都是2,如果不一致比如一个是1一个是2也是可以进行拼接的)
torch.cat(y, 0)
>>>tensor([[-0.2029, -1.2860,  1.0803],
           [ 0.4547,  0.4816,  0.6233],
           [-0.2029, -1.2860,  1.0803],
           [ 0.4547,  0.4816,  0.6233],
           [-0.2029, -1.2860,  1.0803],
           [ 0.4547,  0.4816,  0.6233]])
# 在维度1进行拼接,他们的维度0的shape是一致的
torch.cat(y, 1)
>>>tensor([[-0.2029, -1.2860,  1.0803, -0.2029, -1.2860,  1.0803, -0.2029, -1.2860,1.0803],
           [ 0.4547,  0.4816,  0.6233,  0.4547,  0.4816,  0.6233,  0.4547,  0.4816,0.6233]])
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值