torch.split()

torch.split(tensor, split_size_or_sections, dim=0)

Splits the tensor into chunks. Each chunk is a view of the original tensor.

  • tensor (Tensor) – tensor to split. 输入

  • split_size_or_sections (int) or (list(int)) – size of a single chunk or list of sizes for each chunk

    split_size_or_sections可以是整数也可以是list,int时每个chunk的大小;list里面每个元素对应一个chunk的大小。

  • dim (int) – dimension along which to split the tensor. 切分维度

import torch

a = torch.arange(10).reshape(5,2)
print(a.shape)
aa = torch.split(a, 2)
print(aa)
b,c,d = torch.split(a, 2)
print(b.shape,c.shape,d.shape)
print('*' * 30)
e,f = torch.split(a, [1,4]) # 沿着dim=0的方向,切分为1,4大小的chunk
print(e.shape, f.shape)
print(e,f)

Result:

torch.Size([5, 2])
(tensor([[0, 1],
        [2, 3]]), tensor([[4, 5],
        [6, 7]]), tensor([[8, 9]]))
torch.Size([2, 2]) torch.Size([2, 2]) torch.Size([1, 2])
******************************
torch.Size([1, 2]) torch.Size([4, 2])
tensor([[0, 1]]) tensor([[2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
Example

利用torch.split()来对图片按通道切片

# SHM 
trimap = torch.randn(1,3,512,512)
trimap_softmax = F.softmax(trimap, dim=1) ###
print(trimap_softmax.shape)
# paper: bs, fs, us
bg, fg, unsure = torch.split(trimap_softmax, 1, dim=1)
print(bg.shape, fg.shape, unsure.shape)

Result:

torch.Size([1, 3, 512, 512])
torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 512, 512])

Reference:

[1] https://pytorch.org/docs/stable/generated/torch.split.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

烤粽子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值