【Torch API】pytorch 中index_copy_函数详解

Tensor.index_copy_按照index索引,将Tensor的元素复制到张量中

 

    def index_copy_(self, dim, index, tensor): # real signature unknown; restored from __doc__
        """
        index_copy_(dim, index, tensor) -> Tensor
        
        Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting
        the indices in the order given in :attr:`index`. For example, if ``dim == 0``
        and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the
        ``j``\ th row of :attr:`self`.
        
        The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the
        length of :attr:`index` (which must be a vector), and all other dimensions must
        match :attr:`self`, or an error will be raised.
        
        .. note::
            If :attr:`index` contains duplicate entries, multiple elements from
            :attr:`tensor` will be copied to the same index of :attr:`self`. The result
            is nondeterministic since it depends on which copy occurs last.
        
        Args:
            dim (int): dimension along which to index
            index (LongTensor): indices of :attr:`tensor` to select from
            tensor (Tensor): the tensor containing values to copy
        
        Example::
        
            >>> x = torch.zeros(5, 3)
            >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
            >>> index = torch.tensor([0, 4, 2])
            >>> x.index_copy_(0, index, t)
            tensor([[ 1.,  2.,  3.],
                    [ 0.,  0.,  0.],
                    [ 7.,  8.,  9.],
                    [ 0.,  0.,  0.],
                    [ 4.,  5.,  6.]])
        """
        return _te.Tensor(*(), **{})
  • dim(int) 需要插入的索引维度
  • index(张量) 被插入的索引
  • tensor 包含要复制的值的张量

官方注释例子讲解

这里x是指被复制的张量 ,如例子中是一个 5×3 的矩阵
t 里面包含了需要被插入的值, 如

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]

我们希望把t[0]t[1]t[2] 分别复制到 x的 第0、第4、第2 维度
将index设置为 index = torch.tensor([0, 4, 2]) 即可

官方例子如下:

x = torch.zeros(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 4, 2])
x.index_copy_(0, index, t)

输出

tensor([[ 1.,  2.,  3.],
                [ 0.,  0.,  0.],
                [ 7.,  8.,  9.],
                [ 0.,  0.,  0.],
                [ 4.,  5.,  6.]])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值