张量操作
一、拼接与切分
1.1 torch.cat()
功能:将张量按维度dim进行拼接
tensors:张量序列
dim:要拼接的维度
函数:torch.cat(tensors,dim=0,out=None)
t = torch.ones([2,3])
t_0 = torch.cat([t,t],dim=0)
t_1 = torch.cat([t,t],dim=1)
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0,t_0.shape,t_1,t_1.shape))
输出:shape:torch.Size([4, 3])
shape:torch.Size([2, 6])
1.2 torch.stack()
功能:在新创建的维度dim上进行拼接
tensors:张量序列
dim:要拼接的维度
函数:torch.stack(tensors,dim=0,out=None)
t_stack = torch.stack([t,t,t,t,t],dim=1)
print("\nt_stack:{} shape:{}".format(t_stack,t_stack.shape))
输出:torch.Size([2, 5, 3])
1.3 torch.chunk()
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
input:要切分的张量
chunks:要切分额份数
函数:torch.chunk(input,chunks,dim=0)
【注】chunk向上取整
a = torch.ones((2,5))
list_of_tensor = torch.chunk(a,dim=1,chunks