pytorch tensor合并与分割

本文详细介绍了PyTorch库中用于张量操作的四个关键函数:torch.cat用于沿指定维度连接张量,torch.stack用于在新的维度上堆叠张量,torch.split用于按大小或部分切分,而torch.chunk则按块数切割。这些函数在处理多维数据时非常实用。
摘要由CSDN通过智能技术生成

1. cat

torch.cat(tensors, dim=0, *, out=None) → Tensor
在指定维度上,连接给定tensor序列或empty,除连接的dimension外,所有得的ensor必须有相同的shape
参数:
tensors-具有相同类型的tensor序列,非empty tensor必须具有相同的shape,连接的dimension除外
dim-指定的连接的维度
输出:
连接后的tensor

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

上图分别是在列和行两个维度连接后的结果

2. stack

创建新维度来连接张量序列

torch.stack(tensors, dim=0, *, out=None) → Tensor

参数:
tensors-张量序列,具有相同的size
dim-插入的新维度,必须介于0和连接的tensor的维度之间
输出:
连接后的tensor
在这里插入图片描述
注意:cat和stack的区别
stack连接的tensor必须具有相同的size,否则报错,cat是除连接的维度外,其他维度shape必须相同
如下示例:
在这里插入图片描述

3. split

把一个tensor切分成块,每个块是原tensor的一部分

torch.split(tensor, split_size_or_sections, dim=0)

参数:
tensor-用来切分的tensor
split_size_or_sections (int) or (list(int)) -单个块的size后者是每个块size的list
dim (int) – 以tensor的哪个维度进行切分
输出:
Tuple[Tensor, …]

示例:
在这里插入图片描述

4. chunk

强制将一个tensor切分成指定数量的块,每个块是原tensor的一部分

torch.chunk(input, chunks, dim=0) → List of Tensors

参数:
input (Tensor) – 输入切分的tensor
chunks (int) – 切分块的数量
dim (int) – 以tensor的哪个维度进行切分
输出:
切分后的list

示例:
在这里插入图片描述
注意:split与chunk的区别
区别主要是第二个参数,split第二个参数切分块的size,而chunk是切分块的数量

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值