pytorch中scatter函数的用法

当在PyTorch中需要根据指定的索引来将值分散(scatter)到张量的特定位置时,可以使用scatter函数。这个函数在处理非连续索引的情况下非常有用。

函数写法为:

target.scatter(dim, index, src)
  • target:即目标张量,将在该张量上进行映射
  • src:即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐行进行列填充
  • index: 按照轴方向,在target张量中需要填充的位置

本篇文章演示一下scatter函数的用法:

1. 创建示例目标张量(target tensor),目标张量在dim维度上不小于源张量,其他维度上一般与源张量相同

import torch

# 创建一个大小为[4, 4]的零张量
target = torch.zeros(4, 4)

2. 创建索引张量(index tensor),该张量确定在目标张量中的哪些位置进行散点操作。索引张量的形状通常与源张量相同,但数据类型为整数。

# 创建一个大小为[4, 4]的索引张量
index = torch.tensor([[0, 1, 2, 3],
                     [1, 2, 3, 0],
                     [2, 3, 0, 1],
                     [3, 0, 1, 2]])

3. 创建源张量(src tensor),该张量包含你要分散到目标张量中的值。源张量的形状通常与目标张量相同。

# 创建一个大小为[4, 4]的值张量
values = torch.tensor([[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12],
                      [13, 14, 15, 16]])

4. 使用scatter函数进行散点操作:

# 使用scatter进行散点操作
result = target.scatter(0, index, values)

5. 得到结果:

tensor([[ 1,  6, 11, 16],
        [ 5, 10, 15,  4],
        [ 9, 14,  3,  8],
        [13,  2,  7, 12]])

解释:索引张量第一行第一列索引为0,那么将对应位置的源张量的值‘1’散布到目标张量的行索引为0的对应位置,因此目标向量第一行第一列为1

索引向量第二行第三列索引为3,对应位置源张量值为7,则将4散布到目标张量的行索引为3的对应位置,因此目标向量第四行第三列为7

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值