pytorch每日一学29(torch.cat())在给定的维度上对tensor进行拼接

本文详细介绍了PyTorch中的torch.cat()方法,该方法用于在指定维度上拼接给定的张量。文章解释了参数tensors、dim和out的作用,并通过示例展示了如何正确使用该方法,包括对复数张量的操作。
摘要由CSDN通过智能技术生成

第29个方法

torch.cat(tensors, dim=0, *, out=None) → Tensor

此方法的意思是对给定的tensor在指定的维度上进行拼接,可以看做是torch.split()
torch.chunk()
的逆向操作。

参数介绍:

  • tensors:要拼接的tensor,这里应该是多个tensor(此处只要是tensor就可以,包括复数tensor和量化tensor),例如如果将tensor a,b进行拼接,这里的输入应该是(a, b)。
  • dim:从哪个维度上对tensor进行拼接,0为第一个维度,1为第二个维度。
  • out:输出的tensor

注意: 对于进行拼接的两个张量,它们在除了拼接的维度上,其余维度上的形状应该相等。例如tensors=(a, b), dim=0那么a,b除了dim=1的维度,其余维度上的形状大小应该相等,不然会报错,因为无法进行拼接。

使用方法如下:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

当然此方法对复数tensor也适用:
在这里插入图片描述
指定dim形状不一样时照样可以拼接:
在这里插入图片描述
如果非指定维度形状不等就会报错:
在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值