文章目录
Ⅰ.split([tensor a, tensor b, …], dim = i) 助记:“根据长度不平均拆分”
一、作用
1-1. 将Tensor第i个维度上的维度数拆分为a、b、…,从而生成n个新的Tensor。要保证实际可拆。
二、常用
2-1. 见In[4]
三、特殊
3-1. 无
四、代码
In[2]: import torch
In[3]: c = torch.rand(4, 3, 2)
In[4]: a, b =c.split([3, 1], dim = 0)
In[5]: a.shape, b.shape
Out[5]: (torch.Size([3, 3, 2]), torch.Size([1, 3, 2]))
Ⅱ.split(tensor a, dim = i) 助记:“根据长度平均拆分”
一、作用
1-1. 将Tensor第i个维度上的维度数用a平均拆分。要确保Tensor维度上的维度数能整除a,使生成的Tensor个数为整数。
二、常用
2-1. 见In[4]
三、特殊
3-1. 无
四、代码
In[2]: import torch
In[3]: c = torch.rand(4, 3, 2)
In[4]: a, b, d, e = c.split(1, dim = 0)
In[5]: a.shape, b.shape, d.shape, e.shape
Out[5]:
(torch.Size([1, 3, 2]),
torch.Size([1, 3, 2]),
torch.Size([1, 3, 2]),
torch.Size([1, 3, 2]))
Ⅲ.chunk(tensor a, dim = i) 助记:“根据生成Tensor数量平均拆分”
一、作用
1-1. 将Tensor第i个维度上的维度数 ÷ a平均拆分。要确保Tensor维度上的维度数能整除a,使生成的Tensor上的tensor为整数。
二、常用
2-1. 见In[9]
三、特殊
3-1. 无
四、代码
In[8]: c = torch.rand(6, 3, 2)
In[9]: a, b = c.chunk(2, dim = 0)
In[10]: a.shape, b.shape
Out[10]: (torch.Size([3, 3, 2]), torch.Size([3, 3, 2]))