torch.chunk()——数组的拆分

torch.chunk()——数组的拆分

torch.chunk(input, chunks, dim=0) → List of Tensors

功能:将数组拆分为特定数量的块

输入:

  • input:待拆分的数组
  • chunks:拆分的块数,指定为几,就拆成几
  • dim:拆分的维度,默认沿第1维度拆分

注意:

  • 函数最后返回的是元组类型,包含拆分后的数组

  • 如果输入的数组在指定的维度下不能整除,则拆分得到的最后一块数组的dim维度大小将小于前面所有的数组dim维度大小

  • chunks最大值限制,如果指定的块数超过最大值,则最终只能拆分成最大值数量

  • chunks最大值的计算,input数组在dim维度上大小为a
    c h u n k s m a x = { a 2 , i f a 为偶数 a + 1 2 , i f a 为奇数 chunks_{max}=\left \{ \begin{matrix} \frac{a}{2} \quad&,if\quad a为偶数 \\ \frac{a+1}{2}\quad&,if \quad a为奇数 \end{matrix} \right. chunksmax={2a2a+1,ifa为偶数,ifa为奇数

具体见代码案例

代码案例

import torch
a=torch.arange(20).view(4,5)
b=torch.chunk(a,chunks=2,dim=0)
c=torch.chunk(a,chunks=2,dim=1)
print(type(b))
print(a.shape)
print(len(b))
print(len(c))
print(a)
for i in range(len(b)):
    print(b[i])
    print(b[i].shape)
    # 输出拆分后的形状
for i in range(len(c)):
    print(c[i])
    print(c[i].shape)

输出

# 拆分后返回的是元组类型
<class 'tuple'>
# 拆分前数组形状
torch.Size([4, 5])
# chunks指定为2,无论在哪个维度拆分,都会得到2个数组
2
2
# 拆分前数组
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
# 当dim=0,即在第一维度拆分时
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 第一维度相当于做了一个除法,除以chunks
# 在这里4除以2等于2,所以拆分后,每个数组第一维度大小是2
torch.Size([2, 5])
tensor([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
torch.Size([2, 5])
# 当dim=1,即在第二维度拆分时
tensor([[ 0,  1,  2],
        [ 5,  6,  7],
        [10, 11, 12],
        [15, 16, 17]])
# 当不能整除的时候,最后一个数组在对应维度的尺寸将会比前面的小。
tensor([[ 3,  4],
        [ 8,  9],
        [13, 14],
        [18, 19]])
# 这里最后一个数组第二维度是2,前面的数组维度是3

如果chunks指定的过大

import torch
a=torch.arange(18).view(2,9)
# 拆奇数
b=torch.chunk(a,chunks=8,dim=1)
c=torch.arange(16).view(2,8)
# 拆偶数
d=torch.chunk(c,chunks=7,dim=1)
for i in range(len(b)):
    print(b[i])
for i in range(len(d)):
    print(d[i])

输出

# 9分成8块,最多得到5块

tensor([[ 0,  1],
        [ 9, 10]])
tensor([[ 2,  3],
        [11, 12]])
tensor([[ 4,  5],
        [13, 14]])
tensor([[ 6,  7],
        [15, 16]])
tensor([[ 8],
        [17]])
# 8分成7块,最多得到4块
tensor([[0, 1],
        [8, 9]])
tensor([[ 2,  3],
        [10, 11]])
tensor([[ 4,  5],
        [12, 13]])
tensor([[ 6,  7],
        [14, 15]])

换句话说:除了最后一个数组dim维度上的大小可以为1,前面的数组dim维度上的大小至少是2

官方文档

torch.chunk():https://pytorch.org/docs/stable/generated/torch.chunk.html#torch.chunk

点个赞再走吧

  • 8
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值