torch.split() 与 torch.chunk()

torch.split()

torch.chunk()

区别


两者都是切分tensor操作,有一些略微的不同。

torch.split()

官网:https://pytorch.org/docs/stable/torch.html#torch.split

torch.split(tensorssplit_size_or_sectiondim=0)

 torch.split()作用将tensor分成块结构。

参数:

tesnor:input,待分输入

split_size_or_sections:需要切分的大小(int or list )

dim:切分维度

output:切分后块结构 <class 'tuple'>

当split_size_or_sections为int时,tenor结构和split_size_or_sections,正好匹配,那么ouput就是大小相同的块结构。如果按照split_size_or_sections结构,tensor不够了,那么就把剩下的那部分做一个块处理。

当split_size_or_sections 为list时,那么tensor结构会一共切分成len(list)这么多的小块,每个小块中的大小按照list中的大小决定,其中list中的数字总和应等于该维度的大小,否则会报错(注意这里与split_size_or_sections为int时的情况不同)。

例子:

split_size_or_sections为int型时。

import torch

x = torch.rand(4,8,6)
y = torch.split(x,2,dim=0) #按照4这个维度去分,每大块包含2个小块
for i in y :
    print(i.size())

output:
torch.Size([2, 8, 6])
torch.Size([2, 8, 6])

y = torch.split(x,3,dim=0)#按照4这个维度去分,每大块包含3个小块
for i in y:
    print(i.size())

output:
torch.Size([3, 8, 6])
torch.Size([1, 8, 6])

split_size_or_sections为list型时。

import torch

x = torch.rand(4,8,6)
y = torch.split(x,[2,3,3],dim=1)
for i in y:
    print(i.size())

output:
torch.Size([4, 2, 6])
torch.Size([4, 3, 6])
torch.Size([4, 3, 6])


y = torch.split(x,[2,1,3],dim=1) #2+1+3 等于8,报错
for i in y:
    print(i.size())

output:
split_with_sizes expects split_sizes to sum exactly to 8 (input tensor's size at dimension 1), but got split_sizes=[2, 1, 3]

torch.chunk()

官网:https://pytorch.org/docs/stable/torch.html#torch.chunk

torch.chunk(inputchunksdim=0) → List of Tensors

参数:input需要切分的tensor,chunks(int型)需要切分后的块大小,dim切分的维度。

其基本使用和torch.split()相同。

import torch
x = torch.rand(2,4,6)
a1 = torch.chunk(x,2,dim=1)[0]
a2 = torch.split(x,2,dim=1)[0]
print(torch.equal(a1,a2))

output:
True

区别:

(1)chunks只能是int型,而split_size_or_section可以是list。

(2)chunks在时,不满足该维度下的整除关系,会将块按照维度切分成1的结构。而split会报错。

例子:

import torch
x = torch.rand(2,4,6)
print(torch.chunk(x,5,dim=1)[0].size()) 
### 4不能整除5,返回4个大小为[2, 1, 6]的块,即做块大小为1的切分

output:
torch.Size([2, 1, 6])

print(torch.split(x,5,dim=1)[0].size()) 
### 报错
torch.cat()是一个将多个张量连接起来的函数。它可以看作是torch.split()和torch.chunk()的逆操作。torch.split()函数可以将一个张量分割成指定尺寸或指定个数的小张量,而torch.cat()函数则可以将这些小张量按照指定的维度连接起来。 举个例子来说明,假设有一个2x3的张量x: ``` >>> x = torch.randn(2, 3) >>> x tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]]) ``` 如果使用torch.cat((x, x, x), 0),将会按照行的方向连接三个x张量,得到一个6x3的张量: ``` >>> torch.cat((x, x, x), 0) tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]]) ``` 而如果使用torch.cat((x, x, x), 1),将会按照列的方向连接三个x张量,得到一个2x9的张量: ``` >>> torch.cat((x, x, x), 1) tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497]]) ``` 因此,torch.cat()函数可以将多个张量按照指定的维度连接在一起,得到一个更大的张量。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [pytorch--torch.cat() & torch.split()](https://blog.csdn.net/weixin_42468475/article/details/115336652)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Foneone

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值