torch.split()方法

torch.split(tensorsplit_size_or_sectionsdim=0)

split_size_or_sections 为切分后的每块大小,不是切分为多少块

import torch

x = torch.randn(1, 2, 4, 4)
y = torch.split(x, 1, dim=1)    # 每块大小为1
# print(x[0])
for i in y:
    print(i.size())

a = torch.rand(1, 4, 8, 6)
b = torch.split(a, 2, dim=1)    # 每块大小为2
for i in b:
    print(i.size())


 

©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页