-
torch.cat((input1, input2, ... ), dim=?)
torch.cat()可以将多个tensor在dim维度上进行拼接。如下:x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int) x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int) cat1 = torch.cat((x1, x2),0) cat2 = torch.cat((x1, x2),1) print(cat1,cat2) # 输出 tensor([[11, 21, 31], [21, 31, 41], [12, 22, 32], [22, 32, 42]], dtype=torch.int32) tensor([[11, 21, 31, 12, 22, 32], [21, 31, 41, 22, 32, 42]], dtype=torch.int32)
-
.shape[i]
shape函数的功能是读取tensor某个维度的长度
对于图像来说:
image.shape[0]——图片高
image.shape[1]——图片长
image.shape[2]——图片通道数
而对于矩阵来说:
shape[0]:表示矩阵的行数
shape[1]:表示矩阵的列数
注:-1代表最后一个,所以shape[-1]代表最后一个维度,如在二维张量里,shape[-1]表示列数
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
print(x1.shape[0])
print(x1.shape[1])
print(x1.shape[-1])
# 输出
2
3
3