pytorch torch.split() 与 torch.chunk()

原文链接:https://blog.csdn.net/foneone/article/details/103875250

torch.chunk()

区别


两者都是切分tensor操作,有一些略微的不同。

torch.split()

官网:https://pytorch.org/docs/stable/torch.html#torch.split

torch.split(tensorssplit_size_or_sectiondim=0)

 torch.split()作用将tensor分成块结构。

参数:

tesnor:input,待分输入

split_size_or_sections:需要切分的大小(int or list )

dim:切分维度

output:切分后块结构 <class 'tuple'>

当split_size_or_sections为int时,tenor结构和split_size_or_sections,正好匹配,那么ouput就是大小相同的块结构。如果按照split_size_or_sections结构,tensor不够了,那么就把剩下的那部分做一个块处理。

当split_size_or_sections 为list时,那么tensor结构会一共切分成len(list)这么多的小块,每个小块中的大小按照list中的大小决定,其中list中的数字总和应等于该维度的大小,否则会报错(注意这里与split_size_or_sections为int时的情况不同)。

例子:

split_size_or_sections为int型时。


 
 
  1. import torch
  2. x = torch.rand( 4, 8, 6)
  3. y = torch.split(x, 2,dim= 0) #按照4这个维度去分,每大块包含2个小块
  4. for i in y :
  5. print(i.size())
  6. output:
  7. torch.Size([ 2, 8, 6])
  8. torch.Size([ 2, 8, 6])
  9. y = torch.split(x, 3,dim= 0) #按照4这个维度去分,每大块包含3个小块
  10. for i in y:
  11. print(i.size())
  12. output:
  13. torch.Size([ 3, 8, 6])
  14. torch.Size([ 1, 8, 6])

split_size_or_sections为list型时。


 
 
  1. import torch
  2. x = torch.rand( 4, 8, 6)
  3. y = torch.split(x,[ 2, 3, 3],dim= 1)
  4. for i in y:
  5. print(i.size())
  6. output:
  7. torch.Size([ 4, 2, 6])
  8. torch.Size([ 4, 3, 6])
  9. torch.Size([ 4, 3, 6])
  10. y = torch.split(x,[ 2, 1, 3],dim= 1) #2+1+3 等于8,报错
  11. for i in y:
  12. print(i.size())
  13. output:
  14. split_with_sizes expects split_sizes to sum exactly to 8 (input tenso r's size at dimension 1), but got split_sizes=[2, 1, 3]

torch.chunk()

官网:https://pytorch.org/docs/stable/torch.html#torch.chunk

torch.chunk(inputchunksdim=0) → List of Tensors

参数:input需要切分的tensor,chunks(int型)需要切分后的块大小,dim切分的维度。

其基本使用和torch.split()相同。


 
 
  1. import torch
  2. x = torch.rand( 2, 4, 6)
  3. a1 = torch.chunk(x, 2,dim= 1)[ 0]
  4. a2 = torch.split(x, 2,dim= 1)[ 0]
  5. print(torch.equal(a1,a2))
  6. output:
  7. True

区别:

(1)chunks只能是int型,而split_size_or_section可以是list。

(2)chunks在时,不满足该维度下的整除关系,会将块按照维度切分成1的结构。而split会报错。

例子:


 
 
  1. import torch
  2. x = torch.rand( 2, 4, 6)
  3. print(torch.chunk(x, 5,dim= 1)[ 0].size())
  4. ### 4不能整除5,返回4个大小为[2, 1, 6]的块,即做块大小为1的切分
  5. output:
  6. torch.Size([ 2, 1, 6])
  7. print(torch.split(x, 5,dim= 1)[ 0].size())
  8. ### 报错
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
`.split()` 和 `.chunk()` 都是用来将一个张量按照指定的维度进行切分的方法。 `.split()` 方法将一个张量按照指定的维度切分成多个小张量,返回一个元组,其中每个元素是一个小张量。使用方法为: ```python torch.split(tensor, split_size_or_sections, dim=0) ``` 其中 `tensor` 是待切分的张量,`split_size_or_sections` 可以是一个整数,表示每个小张量的大小,或者是一个元组,表示每个小张量的大小。`dim` 表示要切分的维度。例如,如果要将一个张量按照第1维切分成大小为2的小张量,可以这样写: ```python import torch x = torch.randn(4, 2) splits = torch.split(x, split_size_or_sections=2, dim=0) print(splits) ``` 输出: ``` (tensor([[-0.7967, -0.5588], [ 0.7187, 2.0854]]), tensor([[ 0.4067, -1.0582], [ 0.6215, 0.8995]])) ``` `.chunk()` 方法与 `.split()` 方法类似,也是将一个张量按照指定维度切分成多个小张量,但是 `.chunk()` 方法将切分后的小张量平均分配到多个元组中,并返回一个元组,其中每个元素是一个包含小张量的元组。使用方法为: ```python torch.chunk(tensor, chunks, dim=0) ``` 其中 `tensor` 是待切分的张量,`chunks` 是要分成的小张量的个数,`dim` 表示要切分的维度。例如,如果要将一个张量按照第1维切分成2个小张量,可以这样写: ```python import torch x = torch.randn(4, 2) chunks = torch.chunk(x, chunks=2, dim=0) print(chunks) ``` 输出: ``` (tensor([[-0.7967, -0.5588], [ 0.7187, 2.0854]]), tensor([[ 0.4067, -1.0582], [ 0.6215, 0.8995]])) ``` `.split()` 和 `.chunk()` 的相同点是都可以将一个张量按照指定的维度切分成多个小张量。不同点是,`.split()` 方法可以将小张量的大小指定为任意值,而 `.chunk()` 方法将小张量平均分配到多个元组中。另外,`.split()` 方法返回一个元组,其中每个元素是一个小张量,而 `.chunk()` 方法返回一个元组,其中每个元素是一个包含小张量的元组。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值