tensor的scatter和scatter_add操作,这篇讲得比较详细,这里我就简单总结一下,以a.scatter(dim=dim,index=index,src=src)为例
- len(a.size())=len(index.size())=len(src.size())=dim_num 即三者维度数必须相等
- a与index的关系:a.size(i)≥index.size(i) i≠dim
a.size(dim)与index.size(dim)不存在明确的大小关系。(此条对应原文中的约束3)
比如a的size是(2,3,4),dim=1,index的size是(x,y,z),则x≤2,z≤4,y的取值无所谓,大于0就可以
index中每个位置上的值的取值范围为[0,a.size(1)-1] - index与src的关系:index.size(i)≤src.size(i),即src在每个维度上不小于index就可以
其实感觉就是分两步:
step1. 在src中从左上角开始,按index的size切片或切块(这也是为什么src的在每个维度上都要大于等于index的原因)
step2. 用dim值替换:比如index[2,3,4]的值为0,src[2,3,4]的值为100,dim=1,那么就用100去替代a[2,0,4]的值;再比如,index[1,2,5]的值为3,src[1,2,5]的值为85,dim=2,那么就用85去替代a[1,2,3]的值,若dim=0,那么就用85去替代a[3,2,5]的值。

本文简要总结了PyTorch中tensor的scatter和scatter_add操作。要求包括三者尺寸匹配,例如a.scatter(dim, index, src),其中a、index和src的维度数相等。操作分为两步:按index大小从src切片,然后根据dim值替换a中的相应位置。在实际应用中,发现除dim外的其他维度,a.size(i)必须等于index.size(i),否则会报错。"
53309759,5084701,使用Spark进行数据质量检查,"['Spark', '大数据处理', '数据质量检查', 'Python编程', '数据清洗']
最低0.47元/天 解锁文章
4955

被折叠的 条评论
为什么被折叠?



