前言
torch.split(tensor, split_size_or_sections, dim=0)
"""
tensor:输入张量,要进行分割的张量。
split_size_or_sections:指定分割的方式。可以是一个整数 split_size,表示每个子张量的大小;或者是一个整数列表 split_sizes,表示每个子张量的大小;或者是一个整数值 sections,表示要分#割成的子张量的数量。
dim:指定进行分割的维度。
"""
一、在维度N(批量大小)上进行切片
import torch
input = torch.randn(4, 3, 28, 28)
output = torch.split(input, split_size_or_sections=2, dim=0)
print(output)
(tensor([[[[...]]],
[[...]]]),
tensor([[[[...]]],
[[...]]]))
二、在维度C(通道)上进行切片
import torch
input = torch.randn(4, 6, 28, 28)
output = torch.split(input, split_size_or_sections=3, dim=1)
print(output)
(tensor([[[[...]]],
[[...]]]),
tensor([[[[...]]],
[[...]]]))
三、在维度 H 上进行切片
import torch
input = torch.randn(4, 3, 28, 28)
output = torch.split(input, split_size_or_sections=14, dim=2)
print(output)
(tensor([[[[...]]],
[[...]]]),
tensor([[[[...]]],
[[...]]]))
四、在维度 W 上进行切片
import torch
input = torch.randn(4, 3, 28, 28)
output = torch.split(input, split_size_or_sections=14, dim=3)
print(output)
(tensor([[[[...]]],
[[...]]]),
tensor([[[[...]]],
[[...]]]))