torch 中的 dim 参数,以及 torch.repeat_interleave() 函数

本文介绍了张量中的维度概念,通过实例展示了如何使用torch生成二维张量,并解释了张量的dim参数。接着详细解析了torch.repeat_interleave()函数,包括其参数作用和不同用法,如指定维度复制、不指定维度时的默认行为以及使用张量作为repeat参数。最后提供了官方文档链接以供深入学习。
摘要由CSDN通过智能技术生成

torch 中的 dim 参数,以及torch.repeat_interleave()

张量的维度:用 dim 描述

在线性代数课程中,接触的最多的是以行和列对矩阵进行描述,但是在拥有更多维度的张量中则有所不同

我们使用 torch 生成一个张量

y = torch.tensor([[1, 2], [3, 4]]) 

利用数学形式的表示,我们可以写作
y = [ 1 3 2 4 ] y = \begin{bmatrix}1 & 3 \\ 2 & 4\end{bmatrix} y=[1234]
或者以 torch 的形式描述为
y = [ [ 1 2 ] [ 3 4 ] ] y = \begin{bmatrix} \begin{bmatrix} 1 \\2\end{bmatrix}\begin{bmatrix}3 \\4\end{bmatrix}\end{bmatrix} y=[[12][34]]
第一层括号记作 dim=0,遍历 dim=0 的层次,得到两个张量
y 1 = [ 1 2 ] y 2 = [ 3 4 ] y_1 = \begin{bmatrix} 1 \\ 2\end{bmatrix} \\ y_2 =\begin{bmatrix} 3 \\ 4\end{bmatrix} y1=[12]y2=[34]
第二重括号记作 dim=1,遍历 dim=1 的层次,得到四个数字
y 11 = 1 y 12 = 2 y 21 = 3 y 22 = 4 y_{11} = 1 \\y_{12} =2 \\y_{21} = 3 \\ y_{22} = 4 y11=1y12=2y21=3y22=4

张量的复制:torch.repeat_interleave()

首先复制一下 PyTorch 官方文档的描述

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor

参数有三个

  • input:待处理的 tensor
  • repeat:每个元素重复几次
  • dim:在那个维度上重复

看几个例子,以上一节中的 y y y 为例

dim=0

 torch.repeat_interleave(y, 2, dim=0)

dim=0,在 dim=0,即所有张量上进行操作,每个一维 tensor 复制 1 次

tensor([[1, 2],
        [1, 2],
        [3, 4],
       	[3, 4]])

dim=1

torch.repeat_interleave(y, 3, dim=1)

dim=1 ,即每个一维张量内部元素上进行操作,所有元素复制 2 次

tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])

不提供 dim 参数

y.repeat_interleave(2)

将默认返回 y.Flatten()

tensor([1, 1, 2, 2, 3, 3, 4, 4])

repeat 参数本身也能是一个 tensor

torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)

dim=0,即所有张量上进行操作,第一个 tensor 复制 0 次,第二个复制 1 次

tensor([[1, 2],
        [3, 4],
        [3, 4]])

最后贴上官方文档的链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值