torch.cat函数
torch
中的cat
函数用于沿着指定维度
将张量连接起来。要求除了要连接的轴之外其他的轴必须具有相同的形状。
torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)
给定一个包含多个张量的序列X和Y,通过指定dim
参数可以将它们沿着指定维度
连接在一起。
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
C = torch.cat((A, B), dim=0)
D = torch.cat((A, B), dim=1)
print(C)
print(D)
# 输出
# tensor([[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]])
# tensor([[1, 2, 5, 6],
# [3, 4, 7, 8]])
在二维数据中,dim取值为(0,1)
- 当
dim = 0
时候,按行拼接
(或者叫列合并
),列数相同,拼接数据
的所有行
- 当
dim = 1
时候,按列拼接
(或者叫行合并
),行数相同,拼接数据
的所有列
在三维数据中,dim取值为(0,1,2)
- dim=0:表示沿着第0维度进行拼接。这意味着将两个包含多个矩阵的三维张量连接起来,形成一个更高的三维张量。
- dim=1:表示沿着第1维度进行拼接。这意味着将两个包含多个行向量的三维张量连接起来,形成一个更宽的三维张量。
- dim=2:表示沿着第2维度进行拼接。这意味着将两个包含多个列向量的三维张量连接起来,形成一个更深的三维张量。
三维比二维多了一个维度,0
维度。事实上,三维数据中的 1
和 2
维度,分别对应二维数据的 0
和1
维度,而三维数据中的 0
维度,含义就是 有多少个二维数据
,
比如 :4x3x2
含义就是 4
个 3x2
的矩阵
。
import torch
A = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])
B = torch.tensor([[[9, 10], [11, 12]],
[[13, 14], [15, 16]]])
print(A.shape) # 输出:torch.Size([2, 2, 2])
print(B.shape) # 输出:torch.Size([2, 2, 2])
C = torch.cat((A, B), dim=0)
print(C.shape) # 输出:torch.Size([4, 3, 2])
print(C)
# 输出
torch.Size([2, 2, 2])
torch.Size([2, 2, 2])
torch.Size([4, 2, 2])
tensor([[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]],
[[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16]]])
广播机制
广播机制是在运算过程中,去处理两个形状不同向量的一种手段
a = torch.tensor([[0],
[1],
[2]])
b = torch.tensor([[0, 1]])
print(a.shape) # torch.Size([3, 1])
print(b.shape) # torch.Size([1, 2])
print(a+b) # tensor([[0, 1],
[1, 2],
[2, 3]])
满足广播机制的条件: 按从右往左顺序看两个张量的每一个维度
- 这两个维度的大小相等
- 某个维度 一个张量有,一个张量没有
- 某个维度 一个张量有,一个张量也有但大小是1
不能进行广播:两个张量维度从右往左看,如果出现两个张量在某个维度位置上面,维度大小不相等,且两个维度大小没有一个是1,那么这两个张量一定不能进行广播。
索引与切片
冒号切片
- 单冒号:变量[ start : end ] 左闭右开
- 双冒号:变量[ start : end : step ] 步长可正可负