pytorch的scatter和scatter_add操作

本文简要总结了PyTorch中tensor的scatter和scatter_add操作。要求包括三者尺寸匹配,例如a.scatter(dim, index, src),其中a、index和src的维度数相等。操作分为两步:按index大小从src切片,然后根据dim值替换a中的相应位置。在实际应用中,发现除dim外的其他维度,a.size(i)必须等于index.size(i),否则会报错。" 53309759,5084701,使用Spark进行数据质量检查,"['Spark', '大数据处理', '数据质量检查', 'Python编程', '数据清洗']
摘要由CSDN通过智能技术生成

       tensor的scatter和scatter_add操作,这篇讲得比较详细,这里我就简单总结一下,以a.scatter(dim=dim,index=index,src=src)为例

  1. len(a.size())=len(index.size())=len(src.size())=dim_num 即三者维度数必须相等
  2. a与index的关系:a.size(i)≥index.size(i) i≠dim
    a.size(dim)与index.size(dim)不存在明确的大小关系。(此条对应原文中的约束3)
    比如a的size是(2,3,4),dim=1,index的size是(x,y,z),则x≤2,z≤4,y的取值无所谓,大于0就可以
    index中每个位置上的值的取值范围为[0,a.size(1)-1]
  3. index与src的关系:index.size(i)≤src.size(i),即src在每个维度上不小于index就可以

       其实感觉就是分两步:
       step1. 在src中从左上角开始,按index的size切片或切块(这也是为什么src的在每个维度上都要大于等于index的原因)
       step2. 用dim值替换:比如index[2,3,4]的值为0,src[2,3,4]的值为100,dim=1,那么就用100去替代a[2,0,4]的值;再比如,index[1,2,5]的值为3,src[1,2,5]的值为85,dim=2,那么就用85去替代a[1,2,3]的值,若dim=0,那么就用85去替代a[3,2,5]的值。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值