torch 中的 dim 参数,以及torch.repeat_interleave()
张量的维度:用 dim
描述
在线性代数课程中,接触的最多的是以行和列对矩阵进行描述,但是在拥有更多维度的张量中则有所不同
我们使用 torch 生成一个张量
y = torch.tensor([[1, 2], [3, 4]])
利用数学形式的表示,我们可以写作
y
=
[
1
3
2
4
]
y = \begin{bmatrix}1 & 3 \\ 2 & 4\end{bmatrix}
y=[1234]
或者以 torch 的形式描述为
y
=
[
[
1
2
]
[
3
4
]
]
y = \begin{bmatrix} \begin{bmatrix} 1 \\2\end{bmatrix}\begin{bmatrix}3 \\4\end{bmatrix}\end{bmatrix}
y=[[12][34]]
第一层括号记作 dim=0
,遍历 dim=0
的层次,得到两个张量
y
1
=
[
1
2
]
y
2
=
[
3
4
]
y_1 = \begin{bmatrix} 1 \\ 2\end{bmatrix} \\ y_2 =\begin{bmatrix} 3 \\ 4\end{bmatrix}
y1=[12]y2=[34]
第二重括号记作 dim=1
,遍历 dim=1
的层次,得到四个数字
y
11
=
1
y
12
=
2
y
21
=
3
y
22
=
4
y_{11} = 1 \\y_{12} =2 \\y_{21} = 3 \\ y_{22} = 4
y11=1y12=2y21=3y22=4
张量的复制:torch.repeat_interleave()
首先复制一下 PyTorch
官方文档的描述
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor
参数有三个
input
:待处理的 tensorrepeat
:每个元素重复几次dim
:在那个维度上重复
看几个例子,以上一节中的 y y y 为例
dim=0
torch.repeat_interleave(y, 2, dim=0)
在 dim=0
,在 dim=0
,即所有张量上进行操作,每个一维 tensor 复制 1 次
tensor([[1, 2],
[1, 2],
[3, 4],
[3, 4]])
dim=1
torch.repeat_interleave(y, 3, dim=1)
在 dim=1
,即每个一维张量内部元素上进行操作,所有元素复制 2 次
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
不提供 dim
参数
y.repeat_interleave(2)
将默认返回 y.Flatten()
tensor([1, 1, 2, 2, 3, 3, 4, 4])
repeat
参数本身也能是一个 tensor
torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
在 dim=0
,即所有张量上进行操作,第一个 tensor 复制 0 次,第二个复制 1 次
tensor([[1, 2],
[3, 4],
[3, 4]])
最后贴上官方文档的链接