引言:看了关于torch.cat函数的文章,有点乱,自己总结一篇,关于四维tensor合并。
- 一张图像在计算机中的表示通常为三维tensor(张量),即[channels,height,width] 。也就是一张彩色图片通常有三色通道(R,G,B)组成,高和宽也就是常说的照片大小,比如224x224
- 在图像处理的时候会增加一个变量batch_size,也就是把多少张图片作为一批进行处理。所以就变成了四维张量,即[batch_size,channels,heigth,width],也即是[批量大小,通道数,高,宽]
- 如何判断一个tensor是几维张量最简单的办法就是看中括号数。例如 [[[[1,2,3]]]],是四维张量。
- torch.cat()函数,官方文档是这样写的
torch.
cat
(tensors, dim=0, *, out=None),也就是有两个参数,一个是要合并的张量,一个是在哪个维度上进行合并。
废话少说开始演示。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)
#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
定义了两个四维张量。维度都为[1,1,2,3],即批量大小为1,通道为1,高为2,宽为3