文章目录
张量拼接
使用torch.cat()
功能:将张量按维度dim进行拼接
示例:
t = torch.ones((2, 3))
t_0 = torch.cat([t, t], dim=0)
t_1 = torch.cat([t, t, t], dim=1)
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
结果:
t_0:tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])
使用torch.stack()
功能:新创建一个维度dim,然后在新维度dim上进行拼接
示例:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=0)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
结果:
t_stack:tensor([[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]]) shape:torch.Size([3, 2, 3])
对比
cat不会拓展维度,而stack一定会创建一个新维度,会拓展维度
张量切分
使用torch.chunk()
功能:按照指定维度dim进行平均切分,如果不能整除,那最后一份张量的结果小于平均张量,最后一份张量在dim维度上的大小为余数
示例:
a