【torch算子】torch.index_put和torch.Tensor.index_put_ 的理解

【算子释义】

index_put_(indices, value, accumulate=False) ——> Tensor  # Tensor赋值其他tensor的同时改变自己的value操作
# Puts values from the tensor value into the tensor self using the indices specified in indices (which is a tuple of Tensors). The expression tensor.index_put_(indices, value) is equivalent to tensor[indices] = value. Returns self.
# If accumulate is True, the elements in value are added to self. If accumulate is False, the behavior is undefined if indices contain duplicate elements.

# Parameters
	# indices (tuple of LongTensor) – tensors used to index into self.
	# value (Tensor) – tensor of same dtype as self.
	# accumulate (bool) – whether to accumulate into self
index_put(tensor1, indices, value, accumulate=False) ——> Tensor  # Tensor赋值其tensor,并不改变自己的value操作
# Out-place version of index_put_(). tensor1 corresponds to self in torch.Tensor.index_put_().

【算子使用】

>>> import torch
>>> index = [torch.LongTensor([0,1,2,1]), torch.LongTensor([0,2,0,1])] # inputTensor第一个维度,第二个维度的索引值,其中index的value必须 < inputTensor的对应维度dim值
>>> input = torch.ones(3,3)  # shape -> (3,3)
>>> value = torch.Tensor([5,5,5,5]) # shape -> (4)

# index_put
>>> out_1 = input.index_put(index, value)
>>> index
[tensor([0, 1, 2, 1]), tensor([0, 2, 0, 1])] # 会在input的(0,0), (1,2), (2,0),(1,1) 位置进行赋值操作,见out_
>>> value
tensor([5., 5., 5., 5.])
>>> input   # 值未改变
tensor([[1., 1., 1.],  
        [1., 1., 1.],
        [1., 1., 1.]])
>>> out_1   # 输出的赋值后的tensor
tensor([[5., 1., 1.],  # (0, 0)的值置为5
        [1., 5., 5.],  # (1, 2)的值置为5   # (1, 1)的值置为5
        [5., 1., 1.]]) # (2, 0)的值置为5

# index_put_  例1 
>>> out_2 = input.index_put_(index, value)
>>> input
tensor([[5., 1., 1.],
        [1., 5., 5.],
        [5., 1., 1.]])
>>> out_2
tensor([[5., 1., 1.],
        [1., 5., 5.],
        [5., 1., 1.]])
>>> print(input==out_2)
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

# index_put_  例2 
>>> index_2 =  [torch.LongTensor([0,1,2,2])]
>>> input.shape
torch.Size([3, 3])
>>> value.shape
torch.Size([4])

>>> out_3 = input.index_put_(index_2, value) # vlue的shape和input的shape不符合广播操作,需要修改value的shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: shape mismatch: value tensor of shape [4] cannot be broadcast to indexing result of shape [4, 3]

#  index_put_  例3 
>>> value = torch.Tensor([3,3,3])
>>> input
tensor([[5., 1., 1.],
        [1., 5., 5.],
        [5., 1., 1.]])
>>> value
tensor([3., 3., 3.])
>>> index_2
[tensor([0, 1, 2, 2])]  # 修改(0,0-all), (1,0-all), (2,0-all), (2,0-all)的value
>>> out_3 = input.index_put_(index_2, value)  # input & out_3的值均会被value覆盖
>>> input
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
>>> out_3
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])

#  index_put_  例4 
>>> index_3 = [torch.LongTensor([0,1])] # 改变第一个维度0,1维的值,其他维度所有的值
>>> input
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
>>> value = torch.Tensor([5])
>>> value
tensor([5.])
>>> index_3 = [torch.LongTensor([0,1])]
>>> out_4 = input.index_put_(index_3, value)
>>> input
tensor([[5., 5., 5.],
        [5., 5., 5.],
        [0., 0., 0.]])

【总结】

indices (tuple of LongTensor)  
# ---> 改变inputTensor第一个维度..第n个维度的索引值,其中indices 的value必须 < inputTensor的对应维度dim值,如果tuple中只有一个LongTensor则改变第一个维度对应索引位置的value,其他维的所有value默认均重新赋值value
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值