- torch.chunck()函数是将指定维度的张量拆成n份。
- torch.split()函数是将指定维度的张量等步长拆分,具体拆分几份看步长。
import torch
import numpy as np
def tensor_chunck():
arr = np.ones([3, 3,3])
t = torch.tensor(arr)
print(t)
# torch.chunck()函数是将指定维度的张量拆成2份
t1 = torch.chunk(t,2, dim=2)
print(t1,t1[0].shape,t1[1].shape)
def tensor_split():
arr = np.ones((3,3,3))
t= torch.tensor(arr)
t1 = torch.split(t,1,dim=0)
print(t1)
arr1 = np.ones((3, 3, 3))
t = torch.tensor(arr1)
t1 = torch.split(t, [1,2], dim=0)
print(t1)
if __name__ == '__main__':
tensor_chunck()
tensor_split()