在 PyTorch 中,对张量 (Tensor) 进行拆分通常会用到两个函数:
torch.split [按块大小拆分张量]
torch.chunk [按块数拆分张量]
而对张量 (Tensor) 进行拼接通常会用到另外两个函数:
torch.cat [按已有维度拼接张量]
torch.stack [按新维度拼接张量]
张量的拆分
torch.split() 按照块的大小进行划分
import torch
#定义一个四维张量
x = torch.randn(1, 64,32,32)
#按照维度1 即64 按照块大小为4进行划分,一共划分了16个块,返回的是一个列表
o1=torch.split(x, 4, dim = 1)
print(x.shape)
print(len(o1))
print(o1[1].shape)
torch.chunk() 按照块数进行划分
import torch
#定义一个四维张量
x = torch.randn(1, 64,32,32)
#按照维度1 即64 按照块的数目为4进行划分,划分每个块的大小是16,返回的是一个列表
o1=torch.chunk(x, 4, dim = 1)
print(x.shape)
print(len(o1))
print(o1[1].shape)
张量的拼接
torch.cat(tensor,dim) 在已有的维度上进行拼接
import torch
#定义两个个四维张量
x = torch.randn(1, 64,32,32)
y = torch.randn(1,64, 32,32)
#在